Skip to content

Alternative aggregate functions to calculate histogram values. #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/Histogram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count}
import com.amazon.deequ.analyzers.runners.{IllegalAnalyzerParameterException, MetricCalculationException}
import com.amazon.deequ.metrics.{Distribution, DistributionValue, HistogramMetric}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.functions.{col, sum}
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Row}

import scala.util.{Failure, Try}

/**
* Histogram is the summary of values in a column of a DataFrame. Groups the given column's values,
* and calculates the number of rows with that specific value and the fraction of this value.
* and calculates either number of rows or with that specific value and the fraction of this value or
* sum of values in other column.
*
* @param column Column to do histogram analysis on
* @param binningUdf Optional binning function to run before grouping to re-categorize the
Expand All @@ -37,13 +40,15 @@ import scala.util.{Failure, Try}
* maxBins sets the N.
* This limit does not affect what is being returned as number of bins. It
* always returns the dictinct value count.
* @param aggregateFunction function that implements aggregation logic.
*/
case class Histogram(
column: String,
binningUdf: Option[UserDefinedFunction] = None,
maxDetailBins: Integer = Histogram.MaximumAllowedDetailBins,
where: Option[String] = None,
computeFrequenciesAsRatio: Boolean = true)
computeFrequenciesAsRatio: Boolean = true,
aggregateFunction: AggregateFunction = Count)
extends Analyzer[FrequenciesAndNumRows, HistogramMetric]
with FilterableAnalyzer {

Expand All @@ -58,19 +63,15 @@ case class Histogram(

// TODO figure out a way to pass this in if its known before hand
val totalCount = if (computeFrequenciesAsRatio) {
data.count()
aggregateFunction.total(data)
} else {
1
}

val frequencies = data
val df = data
.transform(filterOptional(where))
.transform(binOptional(binningUdf))
.select(col(column).cast(StringType))
.na.fill(Histogram.NullFieldReplacement)
.groupBy(column)
.count()
.withColumnRenamed("count", Analyzers.COUNT_COL)
val frequencies = query(df)

Some(FrequenciesAndNumRows(frequencies, totalCount))
}
Expand Down Expand Up @@ -125,11 +126,67 @@ case class Histogram(
case _ => data
}
}

private def query(data: DataFrame): DataFrame = {
aggregateFunction.query(this.column, data)
}
}

object Histogram {
val NullFieldReplacement = "NullValue"
val MaximumAllowedDetailBins = 1000
val count_function = "count"
val sum_function = "sum"

sealed trait AggregateFunction {
def query(column: String, data: DataFrame): DataFrame

def total(data: DataFrame): Long

def aggregateColumn(): Option[String]

def function(): String
}

case object Count extends AggregateFunction {
override def query(column: String, data: DataFrame): DataFrame = {
data
.select(col(column).cast(StringType))
.na.fill(Histogram.NullFieldReplacement)
.groupBy(column)
.count()
.withColumnRenamed("count", Analyzers.COUNT_COL)
}

override def aggregateColumn(): Option[String] = None

override def function(): String = count_function

override def total(data: DataFrame): Long = {
data.count()
}
}

case class Sum(aggColumn: String) extends AggregateFunction {
override def query(column: String, data: DataFrame): DataFrame = {
data
.select(col(column).cast(StringType), col(aggColumn).cast(LongType))
.na.fill(Histogram.NullFieldReplacement)
.groupBy(column)
.sum(aggColumn)
.withColumnRenamed("count", Analyzers.COUNT_COL)
}

override def total(data: DataFrame): Long = {
data.groupBy().sum(aggColumn).first().getLong(0)
}

override def aggregateColumn(): Option[String] = {
Some(aggColumn)
}

override def function(): String = sum_function
}
}

object OrderByAbsoluteCount extends Ordering[Row] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import scala.collection._
import scala.collection.JavaConverters._
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
import JsonSerializationConstants._
import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count => HistogramCount, Sum => HistogramSum}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.expr

Expand Down Expand Up @@ -302,6 +303,12 @@ private[deequ] object AnalyzerSerializer
result.addProperty(ANALYZER_NAME_FIELD, "Histogram")
result.addProperty(COLUMN_FIELD, histogram.column)
result.addProperty("maxDetailBins", histogram.maxDetailBins)
// Count is initial and default implementation for Histogram
// We don't include fields below in json to preserve json backward compatibility.
if (histogram.aggregateFunction != Histogram.Count) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment here that:

  • the reason we are excluding Count since it is the current default
  • It was the only previously supported function in the past, and we exclude it for backwards compatibility

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add comments.

result.addProperty("aggregateFunction", histogram.aggregateFunction.function())
result.addProperty("aggregateColumn", histogram.aggregateFunction.aggregateColumn().get)
}

case _ : Histogram =>
throw new IllegalArgumentException("Unable to serialize Histogram with binningUdf!")
Expand Down Expand Up @@ -436,7 +443,10 @@ private[deequ] object AnalyzerDeserializer
Histogram(
json.get(COLUMN_FIELD).getAsString,
None,
json.get("maxDetailBins").getAsInt)
json.get("maxDetailBins").getAsInt,
aggregateFunction = createAggregateFunction(
getOptionalStringParam(json, "aggregateFunction").getOrElse(Histogram.count_function),
getOptionalStringParam(json, "aggregateColumn").getOrElse("")))

case "DataType" =>
DataType(
Expand Down Expand Up @@ -489,12 +499,24 @@ private[deequ] object AnalyzerDeserializer
}

private[this] def getOptionalWhereParam(jsonObject: JsonObject): Option[String] = {
if (jsonObject.has(WHERE_FIELD)) {
Option(jsonObject.get(WHERE_FIELD).getAsString)
getOptionalStringParam(jsonObject, WHERE_FIELD)
}

private[this] def getOptionalStringParam(jsonObject: JsonObject, field: String): Option[String] = {
if (jsonObject.has(field)) {
Option(jsonObject.get(field).getAsString)
} else {
None
}
}

private[this] def createAggregateFunction(function: String, aggregateColumn: String): AggregateFunction = {
function match {
case Histogram.count_function => HistogramCount
case Histogram.sum_function => HistogramSum(aggregateColumn)
case _ => throw new IllegalArgumentException("Wrong aggregate function name: " + function)
}
}
}

private[deequ] object MetricSerializer extends JsonSerializer[Metric[_]] {
Expand Down
20 changes: 20 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
}
}

"compute correct sum metrics " in withSparkSession { sparkSession =>
val dfFull = getDateDf(sparkSession)
val histogram = Histogram("product", aggregateFunction = Histogram.Sum("units")).calculate(dfFull)
assert(histogram.value.isSuccess)

histogram.value.get match {
case hv =>
assert(hv.numberOfBins == 3)
assert(hv.values.size == 3)
assert(hv.values.keys == Set("Furniture", "Cosmetics", "Electronics"))
assert(hv("Furniture").absolute == 55)
assert(hv("Furniture").ratio == 55.0 / (55 + 20 + 60))
assert(hv("Cosmetics").absolute == 20)
assert(hv("Cosmetics").ratio == 20.0 / (55 + 20 + 60))
assert(hv("Electronics").absolute == 60)
assert(hv("Electronics").ratio == 60.0 / (55 + 20 + 60))

}
}

"compute correct metrics on numeric values" in withSparkSession { sparkSession =>
val dfFull = getDfWithNumericValues(sparkSession)
val histogram = Histogram("att2").calculate(dfFull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ class AnalyzerContextTest extends AnyWordSpec
|{"entity":"Column","instance":"item","name":"Distinctness","value":1.0},
|{"entity":"Column","instance":"att1","name":"Completeness","value":1.0},
|{"entity":"Multicolumn","instance":"att1,att2","name":"Uniqueness","value":0.25},
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0},
Comment on lines +89 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required? It is not clear from the changes that the order of the fields in the serialization process is changing. If it is not, can we remove this diff to make the PR clearer ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This is required. This is the reason a test was failing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the information. Is it possible to understand why it is required? The reason I ask is because of backwards compatibility. Will a serialized object that has the older order fail to get deserialized because of this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to dig deeper to explain why it happens. My guess is that there is some kind of HashMap, HashSet that changes order of response.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That's what happens. Below are screenshots from 2 different versions (breakpoint is on this line

)

As you can see after map union it gives two different results. And code here (

) iterates through the HashMap.

Master branch
master
feature/histogram-aggregate branch
feature:histogram-aggregate

|{"entity":"Column","instance":"att1","name":"Histogram.bins","value":2.0},
|{"entity":"Column","instance":"att1","name":"Histogram.abs.a","value":3.0},
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.a","value":0.75},
|{"entity":"Column","instance":"att1","name":"Histogram.abs.b","value":1.0},
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25},
|{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0},
|{"entity":"Dataset","instance":"*","name":"Size","value":4.0}
|{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25}
|]"""
.stripMargin.replaceAll("\n", "")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,112 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers {
assertCorrectlyConvertsAnalysisResults(Seq(result))
}

val histogramSumJson =
"""[
| {
| "resultKey": {
| "dataSetDate": 0,
| "tags": {}
| },
| "analyzerContext": {
| "metricMap": [
| {
| "analyzer": {
| "analyzerName": "Histogram",
| "column": "columnA",
| "maxDetailBins": 1000,
| "aggregateFunction": "sum",
| "aggregateColumn": "columnB"
| },
| "metric": {
| "metricName": "HistogramMetric",
| "column": "columnA",
| "numberOfBins": 10,
| "value": {
| "numberOfBins": 10,
| "values": {
| "some": {
| "absolute": 10,
| "ratio": 0.5
| }
| }
| }
| }
| }
| ]
| }
| }
|]""".stripMargin
val histogramCountJson =
"""[
| {
| "resultKey": {
| "dataSetDate": 0,
| "tags": {}
| },
| "analyzerContext": {
| "metricMap": [
| {
| "analyzer": {
| "analyzerName": "Histogram",
| "column": "columnA",
| "maxDetailBins": 1000
| },
| "metric": {
| "metricName": "HistogramMetric",
| "column": "columnA",
| "numberOfBins": 10,
| "value": {
| "numberOfBins": 10,
| "values": {
| "some": {
| "absolute": 10,
| "ratio": 0.5
| }
| }
| }
| }
| }
| ]
| }
| }
|]""".stripMargin

"Histogram serialization" should "be backward compatible for count" in {
val expected = histogramCountJson
val analyzer = Histogram("columnA")
val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)))
val context = AnalyzerContext(Map(analyzer -> metric))
val result = new AnalysisResult(ResultKey(0), context)
assert(serialize(Seq(result)) == expected)
}

"Histogram serialization" should "properly serialize sum" in {
val expected = histogramSumJson
val analyzer = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB"))
val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)))
val context = AnalyzerContext(Map(analyzer -> metric))
val result = new AnalysisResult(ResultKey(0), context)
assert(serialize(Seq(result)) == expected)
}

"Histogram deserialization" should "be backward compatible for count" in {
val analyzer = Histogram("columnA")
val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)))
val context = AnalyzerContext(Map(analyzer -> metric))
val expected = new AnalysisResult(ResultKey(0), context)
assert(deserialize(histogramCountJson) == List(expected))
}

"Histogram deserialization" should "properly deserialize sum" in {
val analyzer = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB"))
val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10)))
val context = AnalyzerContext(Map(analyzer -> metric))
val expected = new AnalysisResult(ResultKey(0), context)
assert(deserialize(histogramSumJson) == List(expected))
}


def assertCorrectlyConvertsAnalysisResults(
analysisResults: Seq[AnalysisResult],
shouldFail: Boolean = false)
Expand Down