|
| 1 | +package ai.chronon.orchestration |
| 2 | + |
| 3 | +import ai.chronon.orchestration.RepoIndex._ |
| 4 | +import ai.chronon.orchestration.RepoTypes._ |
| 5 | +import ai.chronon.orchestration.utils.CollectionExtensions.IteratorExtensions |
| 6 | +import ai.chronon.orchestration.utils.SequenceMap |
| 7 | + |
| 8 | +import scala.collection.mutable |
| 9 | + |
| 10 | +class RepoIndex[T](proc: ConfProcessor[T]) { |
| 11 | + |
| 12 | + // first pass updates |
| 13 | + private val branchToFileHash: TriMap[Branch, Name, FileHash] = mutable.Map.empty |
| 14 | + private val fileHashToContent: TriMap[Name, FileHash, NodeContent[T]] = mutable.Map.empty |
| 15 | + |
| 16 | + // second pass updates |
| 17 | + private val branchVersionIndex: TriMap[Branch, Name, Version] = mutable.Map.empty |
| 18 | + private val versionSequencer: SequenceMap[Name, GlobalHash] = new SequenceMap[Name, GlobalHash] |
| 19 | + |
| 20 | + def addNodes(fileHashes: mutable.Map[Name, FileHash], |
| 21 | + nodes: Seq[T], |
| 22 | + branch: Branch, |
| 23 | + dryRun: Boolean = true): Seq[VersionUpdate] = { |
| 24 | + |
| 25 | + val newContents = nodes.map { node => |
| 26 | + val data = proc.toLocalData(node) |
| 27 | + val nodeContent = NodeContent(data, node) |
| 28 | + |
| 29 | + require(data.fileHash == fileHashes(data.name), s"File hash mismatch for ${data.name}") |
| 30 | + |
| 31 | + data.name -> (data.fileHash -> nodeContent) |
| 32 | + |
| 33 | + }.toMap |
| 34 | + |
| 35 | + def getContents(name: Name, fileHash: FileHash): NodeContent[T] = { |
| 36 | + |
| 37 | + val incomingContents = newContents.get(name).map(_._2) |
| 38 | + |
| 39 | + lazy val existingContents = fileHashToContent |
| 40 | + .get(name) |
| 41 | + .flatMap(_.get(fileHash)) |
| 42 | + |
| 43 | + incomingContents.orElse(existingContents).get |
| 44 | + } |
| 45 | + |
| 46 | + val globalHashes = mutable.Map.empty[Name, GlobalHash] |
| 47 | + |
| 48 | + // memoizes into globalHashes and recursively computes global hash from parents |
| 49 | + def computeGlobalHash(name: Name): GlobalHash = { |
| 50 | + |
| 51 | + if (globalHashes.contains(name)) return globalHashes(name) |
| 52 | + |
| 53 | + val fileHash = fileHashes(name) |
| 54 | + val content = getContents(name, fileHash) |
| 55 | + |
| 56 | + val localHash = content.localData.localHash |
| 57 | + val parents = content.localData.inputs |
| 58 | + |
| 59 | + // recursively compute parent hashes |
| 60 | + val parentHashes = parents |
| 61 | + .map { parent => |
| 62 | + val parentHash = globalHashes.getOrElse(parent, computeGlobalHash(parent)).hash |
| 63 | + s"${parent.name}:$parentHash" |
| 64 | + |
| 65 | + } |
| 66 | + .mkString(",") |
| 67 | + |
| 68 | + // combine parent hashcode with local hash |
| 69 | + val codeString = s"node=${name.name}:$localHash|parents=$parentHashes" |
| 70 | + val globalHash = GlobalHash(codeString.hashCode().toHexString) |
| 71 | + |
| 72 | + globalHashes.update(name, globalHash) |
| 73 | + globalHash |
| 74 | + } |
| 75 | + |
| 76 | + val newVersions = mutable.Map.empty[Name, Version] |
| 77 | + |
| 78 | + fileHashes.foreach { |
| 79 | + case (name, _) => |
| 80 | + val globalHash = computeGlobalHash(name) |
| 81 | + |
| 82 | + val versionIndex = versionSequencer.potentialIndex(name, globalHash) |
| 83 | + newVersions.update(name, Version("v" + versionIndex.toString)) |
| 84 | + } |
| 85 | + |
| 86 | + val existingVersions = branchVersionIndex.getOrElse(branch, mutable.Map.empty) |
| 87 | + val mainVersions = branchVersionIndex.getOrElse(Branch.main, mutable.Map.empty) |
| 88 | + |
| 89 | + val versionUpdates = VersionUpdate.join(newVersions, existingVersions, mainVersions) |
| 90 | + VersionUpdate.print(versionUpdates) |
| 91 | + |
| 92 | + if (!dryRun) { |
| 93 | + |
| 94 | + newContents.foreach { |
| 95 | + case (name, (fileHash, content)) => update(fileHashToContent, name, fileHash, content) |
| 96 | + } |
| 97 | + |
| 98 | + branchToFileHash.update(branch, fileHashes) |
| 99 | + branchVersionIndex.update(branch, newVersions) |
| 100 | + |
| 101 | + } |
| 102 | + |
| 103 | + versionUpdates |
| 104 | + } |
| 105 | + |
| 106 | + // returns the contents of the files not present in the index |
| 107 | + def diff(incomingFileHashes: mutable.Map[Name, FileHash]): Seq[Name] = { |
| 108 | + |
| 109 | + incomingFileHashes |
| 110 | + .filter { |
| 111 | + case (name, incomingHash) => |
| 112 | + val fileHashMap = fileHashToContent.get(name) |
| 113 | + |
| 114 | + lazy val nameAbsentInIndex = fileHashMap.isEmpty |
| 115 | + lazy val fileHashAbsentForName = !fileHashMap.get.contains(incomingHash) |
| 116 | + |
| 117 | + nameAbsentInIndex || fileHashAbsentForName |
| 118 | + |
| 119 | + } |
| 120 | + .keys |
| 121 | + .toSeq |
| 122 | + } |
| 123 | + |
| 124 | + def pruneBranch(branch: Branch): Unit = { |
| 125 | + |
| 126 | + branchToFileHash.remove(branch) |
| 127 | + |
| 128 | + pruneContents() |
| 129 | + } |
| 130 | + |
| 131 | + private def pruneContents(): Unit = { |
| 132 | + |
| 133 | + // collect unique hashes per name from every branch |
| 134 | + val validHashes: mutable.Map[Name, mutable.HashSet[FileHash]] = innerValues(branchToFileHash) |
| 135 | + |
| 136 | + fileHashToContent.retain { |
| 137 | + case (name, fileHashMap) => |
| 138 | + fileHashMap.retain { |
| 139 | + |
| 140 | + case (fileHash, _) => |
| 141 | + validHashes.get(name) match { |
| 142 | + case None => false // no branch has this name |
| 143 | + case Some(hashes) => hashes.contains(fileHash) // this branch has this fileHash |
| 144 | + } |
| 145 | + |
| 146 | + } |
| 147 | + |
| 148 | + fileHashMap.nonEmpty |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + def addFiles(fileHashes: mutable.Map[Name, FileHash], updatedFiles: Map[String, String], branch: Branch): Unit = { |
| 153 | + |
| 154 | + val nodes: Seq[T] = updatedFiles.iterator.flatMap { |
| 155 | + case (name, content) => |
| 156 | + proc.parse(name, FileContent(content)) |
| 157 | + }.distinct |
| 158 | + |
| 159 | + addNodes(fileHashes, nodes, branch) |
| 160 | + } |
| 161 | + |
| 162 | +} |
| 163 | + |
| 164 | +object RepoIndex { |
| 165 | + |
| 166 | + private case class NodeContent[T](localData: LocalData, conf: T) |
| 167 | + |
| 168 | + private type TriMap[K1, K2, V] = mutable.Map[K1, mutable.Map[K2, V]] |
| 169 | + |
| 170 | + private def update[K1, K2, V](map: TriMap[K1, K2, V], k1: K1, k2: K2, v: V): Unit = |
| 171 | + map.getOrElseUpdate(k1, mutable.Map.empty).update(k2, v) |
| 172 | + |
| 173 | + private def innerValues[K1, K2, V](map: TriMap[K1, K2, V]): mutable.Map[K2, mutable.HashSet[V]] = { |
| 174 | + val result = mutable.Map.empty[K2, mutable.HashSet[V]] |
| 175 | + map.values.foreach { innerMap => |
| 176 | + innerMap.foreach { |
| 177 | + case (k2, v) => |
| 178 | + result.getOrElseUpdate(k2, mutable.HashSet.empty).add(v) |
| 179 | + } |
| 180 | + } |
| 181 | + result |
| 182 | + } |
| 183 | + |
| 184 | +} |
0 commit comments