
Commit 83af0a0

Local scoring (aka Sparkless) using Aardpfark (#41)
1 parent a5975a7 commit 83af0a0

File tree

8 files changed: +433 -7 lines changed


build.gradle

+2
@@ -88,6 +88,8 @@ configure(allProjs) {
   commonsValidatorVersion = '1.6'
   commonsIOVersion = '2.6'
   scoveragePluginVersion = '1.3.1'
+  hadrianVersion = '0.8.5'
+  aardpfarkVersion = '0.1.0-SNAPSHOT'

   mainClassName = 'com.salesforce.Main'
 }

features/src/main/scala/com/salesforce/op/utils/spark/RichRow.scala

+26-6
@@ -69,6 +69,7 @@ object RichRow {

     /**
      * Returns map feature by name
+     *
      * @param fieldName name of map feature
      * @return feature value as instance of Map[String, Any]
      */
@@ -88,19 +89,19 @@ object RichRow {
     /**
      * Returns the value of field named {fieldName}. If the value is null, None is returned.
      */
-    def getOption[T](fieldName: String): Option[T] = getOptionAny(fieldName) collect { case t: T @unchecked => t }
+    def getOption[T](fieldName: String): Option[T] = getOptionAny(fieldName) collect { case t: T@unchecked => t }

     /**
      * Returns the value at position i. If the value is null, None is returned.
      */
-    def getOption[T](i: Integer): Option[T] = getOptionAny(i) collect { case t: T @unchecked => t }
+    def getOption[T](i: Integer): Option[T] = getOptionAny(i) collect { case t: T@unchecked => t }

     /**
      * Returns the value of a given feature casted into the feature type
      *
      * @throws UnsupportedOperationException when schema is not defined.
-     * @throws IllegalArgumentException when fieldName do not exist.
-     * @throws ClassCastException when data type does not match.
+     * @throws IllegalArgumentException      when fieldName do not exist.
+     * @throws ClassCastException            when data type does not match.
      */
     def getFeatureType[T <: FeatureType](f: FeatureLike[T])(implicit conv: FeatureTypeSparkConverter[T]): T =
       conv.fromSpark(getAny(f.name))
@@ -110,12 +111,31 @@ object RichRow {
      * weak type tag of features
      *
      * @throws UnsupportedOperationException when schema is not defined.
-     * @throws IllegalArgumentException when fieldName do not exist.
-     * @throws ClassCastException when data type does not match.
+     * @throws IllegalArgumentException      when fieldName do not exist.
+     * @throws ClassCastException            when data type does not match.
      */
     def getFeatureType[T <: FeatureType](f: TransientFeature)(implicit conv: FeatureTypeSparkConverter[T]): T =
       conv.fromSpark(getAny(f.name))

+    /**
+     * Converts row to a [[collection.mutable.Map]]
+     *
+     * @return a [[collection.mutable.Map]] with all row contents
+     */
+    def toMutableMap: collection.mutable.Map[String, Any] = {
+      val res = collection.mutable.Map.empty[String, Any]
+      val fields = row.schema.fields
+      for {i <- fields.indices} res.put(fields(i).name, row(i))
+      res
+    }
+
+    /**
+     * Converts row to a [[collection.immutable.Map]]
+     *
+     * @return a [[collection.immutable.Map]] with all row contents
+     */
+    def toMap: Map[String, Any] = toMutableMap.toMap
+
   }

 }
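The new toMutableMap / toMap enrichments are what let a Spark Row cross over into the plain Map[String, Any] world that local scoring operates on. A minimal sketch of how they might be used, assuming the usual RichRow._ enrichment import and a made-up two-column dataset:

import org.apache.spark.sql.{Row, SparkSession}
import com.salesforce.op.utils.spark.RichRow._

val spark = SparkSession.builder().master("local[*]").appName("RichRowSketch").getOrCreate()
import spark.implicits._

// Hypothetical two-column dataset, used only for illustration
val row: Row = Seq(("a@example.com", 42.0)).toDF("email", "score").head()

// Convert the row into plain Scala maps keyed by field name
val asMap: Map[String, Any] = row.toMap
val asMutableMap: collection.mutable.Map[String, Any] = row.toMutableMap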

local/build.gradle

+17
@@ -0,0 +1,17 @@
repositories {
  // TODO: remove once Aardpfark release is official
  maven { url 'https://jitpack.io' }
}

dependencies {
  compile project(':core')
  testCompile project(':testkit')

  // PFA serialization for Spark models
  // TODO: replace with official Aardpfark release when ready
  compile "com.github.relateiq:aardpfark:$aardpfarkVersion"

  // Hadrian PFA runtime for JVM
  compileOnly "com.opendatagroup:hadrian:$hadrianVersion"
  testRuntime "com.opendatagroup:hadrian:$hadrianVersion"
}
local/src/main/scala/com/salesforce/op/local/OpWorkflowModelLocal.scala

+163
@@ -0,0 +1,163 @@
/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op.local

import com.ibm.aardpfark.spark.ml.SparkSupport
import com.opendatagroup.hadrian.jvmcompiler.PFAEngine
import com.salesforce.op.OpWorkflowModel
import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import com.salesforce.op.stages.{OPStage, OpTransformer}
import org.apache.spark.ml.SparkMLSharedParamConstants._
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization

import scala.collection.mutable

/**
 * Enrichment for [[OpWorkflowModel]] to allow local scoring functionality
 */
trait OpWorkflowModelLocal {

  /**
   * Enrichment for [[OpWorkflowModel]] to allow local scoring functionality
   *
   * @param model [[OpWorkflowModel]]
   */
  implicit class RichOpWorkflowModel(model: OpWorkflowModel) {

    private implicit val formats = DefaultFormats

    /**
     * Internal PFA model representation
     *
     * @param inputs model inputs mappings
     * @param output output mapping
     * @param engine PFA engine
     */
    private case class PFAModel
    (
      inputs: Map[String, String],
      output: (String, String),
      engine: PFAEngine[AnyRef, AnyRef]
    )

    /**
     * Internal OP model representation
     *
     * @param output output name
     * @param model  model instance
     */
    private case class OPModel(output: String, model: OPStage with OpTransformer)

    /**
     * Prepares a score function for local scoring
     *
     * @return score function for local scoring
     */
    def scoreFunction: ScoreFunction = {
      // Prepare the stages for scoring
      val stagesWithIndex = model.stages.zipWithIndex
      // Collect all OP stages
      val opStages = stagesWithIndex.collect { case (s: OpTransformer, i) => OPModel(s.getOutputFeatureName, s) -> i }
      // Collect all Spark wrapped stages
      val sparkStages = stagesWithIndex.filterNot(_._1.isInstanceOf[OpTransformer]).collect {
        case (s: OPStage with SparkWrapperParams[_], i) if s.getSparkMlStage().isDefined =>
          (s, s.getSparkMlStage().get.asInstanceOf[Transformer].copy(ParamMap.empty), i)
      }
      // Convert Spark wrapped stages into PFA models
      val pfaStages = sparkStages.map { case (opStage, sparkStage, i) => toPFAModel(opStage, sparkStage) -> i }
      // Combine all stages and apply the original order
      val allStages = (opStages ++ pfaStages).sortBy(_._2).map(_._1)
      val resultFeatures = model.getResultFeatures().map(_.name).toSet

      // Score Function
      input: Map[String, Any] => {
        val inputMap = mutable.Map.empty ++= input
        val transformedRow = allStages.foldLeft(inputMap) {
          // For OP Models we simply call transform
          case (row, OPModel(output, stage)) =>
            row += output -> stage.transformKeyValue(row.apply)

          // For PFA Models we execute PFA engine action with json in/out
          case (row, PFAModel(inputs, (out, outCol), engine)) =>
            val inJson = rowToJson(row, inputs)
            val engineIn = engine.jsonInput(inJson)
            val engineOut = engine.action(engineIn)
            val resMap = parse(engineOut.toString).extract[Map[String, Any]]
            row += out -> resMap(outCol)
        }
        transformedRow.filterKeys(resultFeatures.contains).toMap
      }
    }

    /**
     * Convert Spark wrapped stages into PFA Models
     */
    private def toPFAModel(opStage: OPStage with SparkWrapperParams[_], sparkStage: Transformer): PFAModel = {
      // Update input/output params for Spark stages to default ones
      val inParam = sparkStage.getParam(inputCol.name)
      val outParam = sparkStage.getParam(outputCol.name)
      val inputs = opStage.getInputFeatures().map(_.name).map {
        case n if sparkStage.get(inParam).contains(n) => n -> inputCol.name
        case n if sparkStage.get(outParam).contains(n) => n -> outputCol.name
        case n => n -> n
      }.toMap
      val output = opStage.getOutputFeatureName
      sparkStage.set(inParam, inputCol.name).set(outParam, outputCol.name)
      val pfaJson = SparkSupport.toPFA(sparkStage, pretty = true)
      val pfaEngine = PFAEngine.fromJson(pfaJson).head
      PFAModel(inputs, (output, outputCol.name), pfaEngine)
    }

    /**
     * Convert row of Spark values into a json convertible Map
     * See [[FeatureTypeSparkConverter.toSpark]] for all possible values - we invert them here
     */
    private def rowToJson(row: mutable.Map[String, Any], inputs: Map[String, String]): String = {
      val in = inputs.map { case (k, v) => (v, row.get(k)) }.mapValues {
        case Some(v: Vector) => v.toArray
        case Some(v: mutable.WrappedArray[_]) => v.toArray(v.elemTag)
        case Some(v: Map[_, _]) => v.mapValues {
          case v: mutable.WrappedArray[_] => v.toArray(v.elemTag)
          case x => x
        }
        case None | Some(null) => null
        case Some(v) => v
      }
      Serialization.write(in)
    }
  }

}
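With this enrichment in scope, any loaded OpWorkflowModel can produce a ScoreFunction that runs OP transformers directly and Spark-wrapped models through the Hadrian PFA engine, one record at a time. A hedged sketch of the intended call pattern (the workflow, the model path and the feature names below are illustrative assumptions, not part of this commit):

import com.salesforce.op.OpWorkflow
import com.salesforce.op.local._

// Assumes `workflow` is an OpWorkflow with its result features set, and that a model
// was previously trained and saved to the (hypothetical) path below
def scoreLocally(workflow: OpWorkflow, rawRecord: Map[String, Any]): Map[String, Any] = {
  val model = workflow.loadModel("/tmp/op-model")   // load the saved OpWorkflowModel
  val scoreFn: ScoreFunction = model.scoreFunction  // build once, reuse for many records
  scoreFn(rawRecord)                                // raw record in, scored record out
}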
local/src/main/scala/com/salesforce/op/local/OpWorkflowRunnerLocal.scala

+63
@@ -0,0 +1,63 @@
/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op.local

import com.salesforce.op.{OpParams, OpWorkflow}


/**
 * A class for running a TransmogrifAI workflow without Spark.
 *
 * @param workflow the workflow that you want to run (Note: the workflow should have the resultFeatures set)
 */
class OpWorkflowRunnerLocal(val workflow: OpWorkflow) {

  /**
   * Load the model & prepare a score function for local scoring
   *
   * Note: since we use the Spark native [[org.apache.spark.ml.util.MLWriter]] interface
   * to load stages, a Spark session is created internally. So if you do not want to keep
   * an open SparkSession around, make sure to stop it after creating the score function:
   *
   *   val scoreFunction = new OpWorkflowRunnerLocal(workflow).score(params)
   *   // stop the session after creating the scoreFunction if needed
   *   SparkSession.builder().getOrCreate().stop()
   *
   * @param params params to use during scoring
   * @return score function for local scoring
   */
  def score(params: OpParams): ScoreFunction = {
    require(params.modelLocation.isDefined, "Model location must be set in params")
    val model = workflow.loadModel(params.modelLocation.get)
    model.scoreFunction
  }

}
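The runner wraps the same enrichment behind an OpParams-driven entry point. Expanding the scaladoc example into a slightly fuller sketch (how `params` is constructed, e.g. parsed from a JSON config with its modelLocation set, is left out and assumed here):

import com.salesforce.op.{OpParams, OpWorkflow}
import com.salesforce.op.local._
import org.apache.spark.sql.SparkSession

// `params` must have its modelLocation set to the saved model path
def makeScorer(workflow: OpWorkflow, params: OpParams): ScoreFunction = {
  val scoreFn = new OpWorkflowRunnerLocal(workflow).score(params)
  // Stop the Spark session that was created internally while loading the model stages,
  // if it is not needed any further - scoring itself is Spark-free from here on
  SparkSession.builder().getOrCreate().stop()
  scoreFn
}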
local/src/main/scala/com/salesforce/op/local/package.scala

+41
@@ -0,0 +1,41 @@
/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op


package object local extends OpWorkflowModelLocal {

  /**
   * Score function for local scoring: raw record => transformed record
   */
  type ScoreFunction = Map[String, Any] => Map[String, Any]

}
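Since ScoreFunction is just an alias for Map[String, Any] => Map[String, Any], score functions compose like ordinary Scala functions. A small illustrative sketch (the post-processing step is a made-up example, not part of this commit):

import com.salesforce.op.local.ScoreFunction

// Hypothetical post-processing step: keep only the selected keys of the scored record
def keepOnly(keys: Set[String])(scoreFn: ScoreFunction): ScoreFunction =
  scoreFn.andThen(_.filter { case (k, _) => keys.contains(k) })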
