Skip to content

Commit a77ee24

Browse files
committed
working tests
1 parent 2fa86a4 commit a77ee24

File tree

6 files changed

+234
-81
lines changed

6 files changed

+234
-81
lines changed

api/py/ai/chronon/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def join_part_output_table_name(join, jp, full_name: bool = False):
262262
def partOutputTable(jp: JoinPart): String = (Seq(join.metaData.outputTable) ++ Option(jp.prefix) :+
263263
jp.groupBy.metaData.cleanName).mkString("_")
264264
"""
265+
print(join)
265266
if not join.metaData.name and isinstance(join, api.Join):
266267
__set_name(join, api.Join, "joins")
267268
return "_".join(

orchestration/src/main/scala/ai/chronon/orchestration/RepoIndex.scala

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import ai.chronon.orchestration.RepoIndex._
44
import ai.chronon.orchestration.RepoTypes._
55
import ai.chronon.orchestration.utils.CollectionExtensions.IteratorExtensions
66
import ai.chronon.orchestration.utils.SequenceMap
7+
import ai.chronon.orchestration.utils.StringExtensions.StringOps
78
import org.apache.logging.log4j.scala.Logging
89

910
import scala.collection.mutable
1011

11-
class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
12+
class RepoIndex[T >: Null](proc: ConfProcessor[T]) extends Logging {
1213

1314
// first pass updates
1415
private val branchToFileHash: TriMap[Branch, Name, FileHash] = mutable.Map.empty
@@ -19,30 +20,14 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
1920
private val versionSequencer: SequenceMap[Name, GlobalHash] = new SequenceMap[Name, GlobalHash]
2021

2122
def addNodes(fileHashes: mutable.Map[Name, FileHash],
22-
nodes: Seq[T],
23+
newNodes: Seq[T],
2324
branch: Branch,
2425
dryRun: Boolean = true): Seq[VersionUpdate] = {
2526

26-
val newContents = nodes.map { node =>
27-
val data = proc.toLocalData(node)
28-
val nodeContent = NodeContent(data, node)
29-
30-
require(data.fileHash == fileHashes(data.name), s"File hash mismatch for ${data.name}")
31-
32-
data.name -> (data.fileHash -> nodeContent)
33-
34-
}.toMap
35-
36-
def getContents(name: Name, fileHash: FileHash): NodeContent[T] = {
37-
38-
val incomingContents = newContents.get(name).map(_._2)
39-
40-
lazy val existingContents = fileHashToContent
41-
.get(name)
42-
.flatMap(_.get(fileHash))
43-
44-
incomingContents.orElse(existingContents).get
45-
}
27+
val newContents = buildContentMap(proc, newNodes, fileHashes)
28+
val enrichedFileHashes = newContents.map {
29+
case (name, content) => name -> content.localData.fileHash
30+
} ++ fileHashes
4631

4732
val globalHashes = mutable.Map.empty[Name, GlobalHash]
4833

@@ -51,8 +36,27 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
5136

5237
if (globalHashes.contains(name)) return globalHashes(name)
5338

54-
val fileHash = fileHashes(name)
55-
val content = getContents(name, fileHash)
39+
val fileHash = enrichedFileHashes.get(name) match {
40+
case Some(hash) => hash
41+
42+
// this could be an artifact related to unchanged files on the branch
43+
// we reach out to content index
44+
// artifacts are just names with no content - so there should be just one entry
45+
case None =>
46+
val hashToContent = fileHashToContent(name)
47+
48+
require(hashToContent.nonEmpty, s"Expected 1 entry for artifact $name, found none")
49+
require(hashToContent.size == 1, s"Expected 1 entry for artifact $name, found ${hashToContent.size}")
50+
51+
hashToContent.head._1
52+
}
53+
54+
val content = if (newContents.contains(name)) {
55+
newContents(name)
56+
} else {
57+
// fetch
58+
fileHashToContent(name)(fileHash)
59+
}
5660

5761
val localHash = content.localData.localHash
5862
val parents = content.localData.inputs
@@ -70,7 +74,7 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
7074

7175
logger.info(s"codeString: $codeString")
7276

73-
val globalHash = GlobalHash(codeString.hashCode().toHexString)
77+
val globalHash = GlobalHash(codeString.md5)
7478

7579
globalHashes.update(name, globalHash)
7680
globalHash
@@ -90,12 +94,12 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
9094
val mainVersions = branchVersionIndex.getOrElse(Branch.main, mutable.Map.empty)
9195

9296
val versionUpdates = VersionUpdate.join(newVersions, existingVersions, mainVersions)
93-
VersionUpdate.print(versionUpdates)
9497

9598
if (!dryRun) {
9699

100+
logger.info("Not a dry run! Inserting new nodes into the index into branch: " + branch.name)
97101
newContents.foreach {
98-
case (name, (fileHash, content)) => update(fileHashToContent, name, fileHash, content)
102+
case (name, content) => update(fileHashToContent, name, content.localData.fileHash, content)
99103
}
100104

101105
val newVersions = globalHashes.map {
@@ -105,7 +109,7 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
105109
name -> version
106110
}
107111

108-
branchToFileHash.update(branch, fileHashes)
112+
branchToFileHash.update(branch, enrichedFileHashes)
109113
branchVersionIndex.update(branch, newVersions)
110114

111115
}
@@ -141,7 +145,7 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
141145
private def pruneContents(): Unit = {
142146

143147
// collect unique hashes per name from every branch
144-
val validHashes: mutable.Map[Name, mutable.HashSet[FileHash]] = innerValues(branchToFileHash)
148+
val validHashes: mutable.Map[Name, mutable.HashSet[FileHash]] = innerKeyToValueSet(branchToFileHash)
145149

146150
fileHashToContent.retain {
147151
case (name, fileHashMap) =>
@@ -173,14 +177,12 @@ class RepoIndex[T](proc: ConfProcessor[T]) extends Logging {
173177

174178
object RepoIndex {
175179

176-
private case class NodeContent[T](localData: LocalData, conf: T)
177-
178180
private type TriMap[K1, K2, V] = mutable.Map[K1, mutable.Map[K2, V]]
179181

180182
private def update[K1, K2, V](map: TriMap[K1, K2, V], k1: K1, k2: K2, v: V): Unit =
181183
map.getOrElseUpdate(k1, mutable.Map.empty).update(k2, v)
182184

183-
private def innerValues[K1, K2, V](map: TriMap[K1, K2, V]): mutable.Map[K2, mutable.HashSet[V]] = {
185+
private def innerKeyToValueSet[K1, K2, V](map: TriMap[K1, K2, V]): mutable.Map[K2, mutable.HashSet[V]] = {
184186
val result = mutable.Map.empty[K2, mutable.HashSet[V]]
185187
map.values.foreach { innerMap =>
186188
innerMap.foreach {
@@ -191,4 +193,54 @@ object RepoIndex {
191193
result
192194
}
193195

196+
/**
197+
 * Takes parsed nodes from the repo parser and builds a local content index
198+
* We treat inputs and outputs that are not present in FileHashes as artifacts
199+
* For these artifacts we create additional entries in the result
200+
*/
201+
def buildContentMap[T >: Null](proc: ConfProcessor[T],
202+
nodes: Seq[T],
203+
fileHashes: mutable.Map[Name, FileHash]): mutable.Map[Name, NodeContent[T]] = {
204+
205+
val contentMap = mutable.Map.empty[Name, NodeContent[T]]
206+
207+
// first pass - update non-artifact contents
208+
for (
209+
node <- nodes;
210+
nodeContent <- proc.nodeContents(node)
211+
) {
212+
213+
val name = nodeContent.localData.name
214+
contentMap.update(name, nodeContent)
215+
216+
def updateContents(artifactName: Name, isOutput: Boolean): Unit = {
217+
218+
// artifacts are not present in file hashes
219+
if (fileHashes.contains(artifactName)) return
220+
221+
val existingParents = if (contentMap.contains(artifactName)) {
222+
contentMap(artifactName).localData.inputs
223+
} else {
224+
Seq.empty
225+
}
226+
227+
val newParents = if (isOutput) Seq(name) else Seq.empty
228+
229+
val parents = (existingParents ++ newParents).distinct
230+
231+
val artifactData = LocalData.forArtifact(artifactName, parents)
232+
val artifactContent = NodeContent[T](artifactData, null)
233+
234+
contentMap.update(artifactName, artifactContent)
235+
236+
}
237+
238+
nodeContent.localData.outputs.foreach { output => updateContents(output, isOutput = true) }
239+
nodeContent.localData.inputs.foreach { input => updateContents(input, isOutput = false) }
240+
241+
}
242+
243+
contentMap
244+
}
245+
194246
}
Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,96 @@
11
package ai.chronon.orchestration
22

3+
import ai.chronon.orchestration.utils.StringExtensions.StringOps
4+
5+
/**
6+
* Types relevant to the orchestration layer.
7+
* It is very easy to get raw strings mixed up in the indexing logic.
8+
* So we guard them using case classes. For that reason we have a lot of case classes here.
9+
*/
310
object RepoTypes {
411

12+
/**
13+
* name of the node
14+
* example: group_bys.<team>.<file>.<var>, joins.<team>.<file>.<var>, staging_queries.<team>.<file>.<var>
15+
* and also table.<namespace>.<name> - adding a dummy node for the table makes the code easier to write
16+
*/
517
case class Name(name: String)
618

719
case class Branch(name: String)
820

21+
object Branch {
22+
val main: Branch = Branch("main")
23+
}
24+
25+
/**
26+
* Take the file content string and hashes it
27+
* Whenever this changes the cli will upload the file into the index.
28+
*/
929
case class FileHash(hash: String)
1030

31+
/**
32+
* Local hash represents the computation defined in the file.
33+
 * In the Chronon API, any field other than metadata is
34+
* considered to impact the computation & consequently the output.
35+
*/
1136
case class LocalHash(hash: String)
1237

38+
/**
39+
* Global hash represents the computation defined in the file and all its dependencies.
40+
* We recursively scan all the parents of the node to compute the global hash.
41+
*
42+
* `global_hash(node) = hash(node.local_hash + node.parents.map(global_hash))`
43+
*/
1344
case class GlobalHash(hash: String)
1445

46+
/**
47+
* Local data represents relevant information for lineage tracking
48+
* that can be computed by parsing a file in *isolation*.
49+
*/
1550
case class LocalData(name: Name, fileHash: FileHash, localHash: LocalHash, inputs: Seq[Name], outputs: Seq[Name])
1651

17-
object Branch {
18-
val main: Branch = Branch("main")
52+
object LocalData {
53+
54+
def forArtifact(name: Name, parents: Seq[Name]): LocalData = {
55+
56+
val nameHash = name.name.md5
57+
58+
LocalData(
59+
name,
60+
FileHash(nameHash),
61+
LocalHash(nameHash),
62+
inputs = parents,
63+
outputs = Seq.empty
64+
)
65+
}
1966
}
2067

68+
/**
69+
* Node content represents the actual data that is stored in the index.
70+
 * It pairs the parsed local data with the original conf object it came from.
71+
*/
2172
case class Version(name: String)
2273

74+
/**
75+
* Content of the compiled file
76+
* Currently TSimpleJsonProtocol serialized StagingQuery, Join, GroupBy thrift objects.
77+
* Python compile.py will serialize user's python code into these objects and
78+
* the [[RepoParser]] will pick them up and sync into [[RepoIndex]].
79+
*/
2380
case class FileContent(content: String) {
24-
def hash: FileHash = FileHash(content.hashCode().toHexString)
81+
def hash: FileHash = FileHash(content.md5)
2582
}
2683

84+
case class NodeContent[T](localData: LocalData, conf: T)
85+
2786
case class Table(name: String)
2887

88+
/**
89+
* To make the code testable, we parameterize the Config with `T`
90+
* You can see how this is used in [[RepoIndexSpec]]
91+
*/
2992
trait ConfProcessor[T] {
30-
def toLocalData(t: T): LocalData
93+
def nodeContents(t: T): Seq[NodeContent[T]]
3194
def parse(name: String, fileContent: FileContent): Seq[T]
3295
}
3396
}

orchestration/src/main/scala/ai/chronon/orchestration/VersionUpdate.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import scala.collection.mutable
77

88
case class VersionUpdate(name: Name, previous: Option[Version], next: Option[Version], main: Option[Version]) {
99

10-
private def isChanged: Boolean = next != main || next != previous
11-
1210
private def toRow: Seq[String] =
1311
Seq(
1412
name.name,
@@ -30,7 +28,6 @@ object VersionUpdate {
3028
.map { name =>
3129
VersionUpdate(name, previous.get(name), next.get(name), main.get(name))
3230
}
33-
.filter(_.isChanged)
3431
.toSeq
3532
.sortBy(_.name.name)
3633
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package ai.chronon.orchestration.utils
2+
3+
import ai.chronon.api.Constants
4+
5+
import java.security.MessageDigest
6+
7+
object StringExtensions {
8+
9+
lazy val digester: ThreadLocal[MessageDigest] = new ThreadLocal[MessageDigest]() {
10+
override def initialValue(): MessageDigest = MessageDigest.getInstance("MD5")
11+
}
12+
13+
implicit class StringOps(s: String) {
14+
def md5: String =
15+
digester
16+
.get()
17+
.digest(s.getBytes(Constants.UTF8))
18+
.map("%02x".format(_))
19+
.mkString
20+
.take(8)
21+
}
22+
}

0 commit comments

Comments
 (0)