Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[v1.4.1] Java bug-fix cherry pick #14834

Merged
merged 3 commits into from
Apr 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,23 @@ abstract class BaseModule {

/**
* Run prediction and collect the outputs.
* @param evalData
* @param evalData dataIter to do the Inference
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`.
* The concatenation process will be like
* {{{
* outputBatches = [
* [a1, a2, a3], // batch a
* [b1, b2, b3] // batch b
* ]
* result = [
* NDArray, // [a1, b1]
* NDArray, // [a2, b2]
* NDArray, // [a3, b3]
* ]
* }}}
* Where each element is concatenation of the outputs for all the mini-batches.
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
Expand All @@ -264,7 +276,8 @@ abstract class BaseModule {
s"in mini-batches (${out.size})." +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
val oBT = outputBatches.transpose
val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
import org.apache.mxnet.io._

class ModuleSuite extends FunSuite with BeforeAndAfterAll {

class myModule(symbol : Symbol) extends Module (symbol) {
override def predictEveryBatch(evalData: DataIter,
numBatch: Int = 1, reset: Boolean = true):
IndexedSeq[IndexedSeq[NDArray]] = {
val data = IndexedSeq(
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 4))
)
List.fill(numBatch)(data).toIndexedSeq
}
}

test("predict") {
val sym = Symbol.Variable("data")
val mod = new myModule(sym)
val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
var output = mod.predict(dummyIter, 1)
require(output(0).shape == Shape(1, 10, 1))
require(output(1).shape == Shape(1, 10, 1))
require(output(2).shape == Shape(1, 10, 4))
output = mod.predict(dummyIter, 2)
require(output(0).shape == Shape(2, 10, 1))
require(output(1).shape == Shape(2, 10, 1))
require(output(2).shape == Shape(2, 10, 4))
}

test ("model dtype") {
val dType = DType.Float32
val dShape = Shape(3, 8, 7)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import java.security.MessageDigest
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* This object will generate the Scala documentation of the Scala/Java APIs
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase {

/**
* Main method used to generate code and write to files
* A hash check placed at the end to verify changes
* @param args Input args
*/
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
Expand All @@ -40,13 +44,25 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val finalHash = hashCollector.mkString("\n")
}

/**
* Generate MD5 result from an input string
* Encoded in UTF-8
* @param input The input string
* @return A MD5 value from the string
*/
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

/**
* Type-safe class body generation for NDArray/Symbol
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -57,11 +73,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"package org.apache.mxnet",
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"import org.apache.mxnet.annotation.Experimental",
generated)
}

/**
* Non Type-safe interface of Scala Symbol/NDArray
* It includes class definition : e.g class SymbolBase
* and function definitions : e.g def softmax(...)(...)(...) : NDArray
* Users can directly use the api by calling NDArray.<function_name>
* It support both positional input or Map input
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -85,34 +112,53 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
if (isSymbol) "SymbolBase" else "NDArrayBase",
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

def javaClassGen(filePath : String) : String = {
/**
* Type-safe interface of Java NDArray
* @param FILE_PATH File path write the file to
* @return MD5 String
*/
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
* Kill the capital lettered method such as Cast vs cast
* As it defined by default it deprecated
*/
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).map(absClassFunction => {
val (absFuncs, paramClassUncleaned) =
absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
* Kill the capital lettered method such as Cast vs cast
* As it defined by default it deprecated
*/
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
}).toSeq
}).toSeq.unzip
val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
writeFile(
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs, Some(paramClass))
}

/**
* Generate Scala docs from the function description
* @param func The function case class
* @param withParam Whether to generate param field
* @return A formatted string for the function description
*/
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
def fixDesc(desc: String): String = {
var curDesc = desc
Expand Down Expand Up @@ -146,7 +192,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}

def generateAPISignature(func: Func, isSymbol: Boolean): String = {
/**
* Generate the function interface
* e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
* @param func The function case class
* @param isSymbol Check if generate Symbol function, NDArray otherwise
* @param typeParameter Type param specifically used in Random Module
* @return Formatted string for the function
*/
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDef = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)
Expand All @@ -162,10 +216,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val returnType = func.returnType

s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : Func) : String = {
/**
* Generate Java function interface
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
Expand Down Expand Up @@ -204,54 +263,67 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
(s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| """.stripMargin,
s"""/**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin)
} else {
argDef += "out : NDArray"
s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
(s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin, "")
}
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {
/**
* Write the formatted string to file
* @param FILE_PATH Location of the file writes to
* @param packageDef Package definition
* @param className Class name
* @param imports Packages need to import
* @param absFuncs All formatted functions
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String],
paramClass: Option[Seq[String]] = None): String = {

val finalStr =
s"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
| * Licensed to the Apache Software Foundation (ASF) under one or more
| * contributor license agreements. See the NOTICE file distributed with
| * this work for additional information regarding copyright ownership.
| * The ASF licenses this file to You under the Apache License, Version 2.0
| * (the "License"); you may not use this file except in compliance with
| * the License. You may obtain a copy of the License at
| *
| * http://www.apache.org/licenses/LICENSE-2.0
| *
| * Unless required by applicable law or agreed to in writing, software
| * distributed under the License is distributed on an "AS IS" BASIS,
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| * See the License for the specific language governing permissions and
| * limitations under the License.
| */
|
|$packageDef
|
|import org.apache.mxnet.annotation.Experimental
|$imports
|
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin
|}
|${paramClass.getOrElse(Seq()).mkString("\n")}
|""".stripMargin


val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
Expand Down