Skip to content

Commit 1d36659

Browse files
committed
[query] Extract Backend Methods called from Python into Py4JBackendExtensions
1 parent bf3c9da commit 1d36659

File tree

18 files changed

+406
-434
lines changed

18 files changed

+406
-434
lines changed

hail/python/hail/backend/py4j_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,17 @@ def _rpc(self, action, payload) -> Tuple[bytes, Optional[dict]]:
237237

238238
def persist_expression(self, expr):
239239
t = expr.dtype
240-
return construct_expr(JavaIR(t, self._jbackend.executeLiteral(self._render_ir(expr._ir))), t)
240+
return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t)
241241

242242
def _is_registered_ir_function_name(self, name: str) -> bool:
243243
return name in self._registered_ir_function_names
244244

245245
def set_flags(self, **flags: Mapping[str, str]):
246-
available = self._jbackend.availableFlags()
246+
available = self._jbackend.pyAvailableFlags()
247247
invalid = []
248248
for flag, value in flags.items():
249249
if flag in available:
250-
self._jbackend.setFlag(flag, value)
250+
self._jbackend.pySetFlag(flag, value)
251251
else:
252252
invalid.append(flag)
253253
if len(invalid) != 0:
@@ -256,7 +256,7 @@ def set_flags(self, **flags: Mapping[str, str]):
256256
)
257257

258258
def get_flags(self, *flags) -> Mapping[str, str]:
259-
return {flag: self._jbackend.getFlag(flag) for flag in flags}
259+
return {flag: self._jbackend.pyGetFlag(flag) for flag in flags}
260260

261261
def _add_reference_to_scala_backend(self, rg):
262262
self._jbackend.pyAddReference(orjson.dumps(rg._config).decode('utf-8'))

hail/python/hail/ir/ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3880,7 +3880,7 @@ def __del__(self):
38803880
if Env._hc:
38813881
backend = Env.backend()
38823882
assert isinstance(backend, Py4JBackend)
3883-
backend._jbackend.removeJavaIR(self._id)
3883+
backend._jbackend.pyRemoveJavaIR(self._id)
38843884

38853885

38863886
class JavaIR(IR):

hail/python/hail/ir/table_ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,4 +1215,4 @@ def __del__(self):
12151215
if Env._hc:
12161216
backend = Env.backend()
12171217
assert isinstance(backend, Py4JBackend)
1218-
backend._jbackend.removeJavaIR(self._id)
1218+
backend._jbackend.pyRemoveJavaIR(self._id)

hail/python/test/hail/genetics/test_reference_genome.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def assert_rg_loaded_correctly(name):
194194
# loading different reference genome with same name should fail
195195
# (different `test_rg_o` definition)
196196
with pytest.raises(FatalError):
197-
hl.read_matrix_table(resource('custom_references_2.t')).count()
197+
hl.read_table(resource('custom_references_2.t')).count()
198198

199199
assert hl.read_matrix_table(resource('custom_references.mt')).count_rows() == 14
200200
assert_rg_loaded_correctly('test_rg_1')

hail/src/main/scala/is/hail/backend/Backend.scala

Lines changed: 19 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package is.hail.backend
22

33
import is.hail.asm4s._
4+
import is.hail.backend.Backend.jsonToBytes
45
import is.hail.backend.spark.SparkBackend
56
import is.hail.expr.ir.{
67
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
78
SortField, TableIR, TableReader,
89
}
9-
import is.hail.expr.ir.functions.IRFunctionRegistry
1010
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
1111
import is.hail.io.{BufferSpec, TypedCodecSpec}
1212
import is.hail.io.fs._
@@ -20,16 +20,14 @@ import is.hail.types.virtual.{BlockMatrixType, TFloat64}
2020
import is.hail.utils._
2121
import is.hail.variant.ReferenceGenome
2222

23-
import scala.collection.JavaConverters._
2423
import scala.collection.mutable
2524
import scala.reflect.ClassTag
2625

2726
import java.io._
2827
import java.nio.charset.StandardCharsets
2928

30-
import com.fasterxml.jackson.core.StreamReadConstraints
3129
import org.json4s._
32-
import org.json4s.jackson.{JsonMethods, Serialization}
30+
import org.json4s.jackson.JsonMethods
3331
import sourcecode.Enclosing
3432

3533
object Backend {
@@ -41,13 +39,6 @@ object Backend {
4139
s"hail_query_$id"
4240
}
4341

44-
private var irID: Int = 0
45-
46-
def nextIRID(): Int = {
47-
irID += 1
48-
irID
49-
}
50-
5142
def encodeToOutputStream(
5243
ctx: ExecuteContext,
5344
t: PTuple,
@@ -66,6 +57,9 @@ object Backend {
6657
assert(t.isFieldDefined(off, 0))
6758
codec.encode(ctx, elementType, t.loadField(off, 0), os)
6859
}
60+
61+
def jsonToBytes(f: => JValue): Array[Byte] =
62+
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
6963
}
7064

7165
abstract class BroadcastValue[T] { def value: T }
@@ -75,28 +69,8 @@ trait BackendContext {
7569
}
7670

7771
abstract class Backend extends Closeable {
78-
// From https://github.com/hail-is/hail/issues/14580 :
79-
// IR can get quite big, especially as it can contain an arbitrary
80-
// amount of encoded literals from the user's python session. This
81-
// was a (controversial) restriction imposed by Jackson and should be lifted.
82-
//
83-
// We remove this restriction for all backends, and we do so here, in the
84-
// constructor since constructing a backend is one of the first things that
85-
// happens and this constraint should be overrided as early as possible.
86-
StreamReadConstraints.overrideDefaultStreamReadConstraints(
87-
StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()
88-
)
89-
9072
val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
9173

92-
protected[this] def addJavaIR(ir: BaseIR): Int = {
93-
val id = Backend.nextIRID()
94-
persistedIR += (id -> ir)
95-
id
96-
}
97-
98-
def removeJavaIR(id: Int): Unit = persistedIR.remove(id)
99-
10074
def defaultParallelism: Int
10175

10276
def canExecuteParallelTasksOnDriver: Boolean = true
@@ -131,30 +105,7 @@ abstract class Backend extends Closeable {
131105
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
132106
: CompiledFunction[T]
133107

134-
var references: Map[String, ReferenceGenome] = Map.empty
135-
136-
def addDefaultReferences(): Unit =
137-
references = ReferenceGenome.builtinReferences()
138-
139-
def addReference(rg: ReferenceGenome): Unit = {
140-
references.get(rg.name) match {
141-
case Some(rg2) =>
142-
if (rg != rg2) {
143-
fatal(
144-
s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " +
145-
s"@1",
146-
references.keys.truncatable("\n "),
147-
)
148-
}
149-
case None =>
150-
references += (rg.name -> rg)
151-
}
152-
}
153-
154-
def hasReference(name: String) = references.contains(name)
155-
156-
def removeReference(name: String): Unit =
157-
references -= name
108+
def references: mutable.Map[String, ReferenceGenome]
158109

159110
def lowerDistributedSort(
160111
ctx: ExecuteContext,
@@ -189,9 +140,6 @@ abstract class Backend extends Closeable {
189140

190141
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T
191142

192-
private[this] def jsonToBytes(f: => JValue): Array[Byte] =
193-
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
194-
195143
final def valueType(s: String): Array[Byte] =
196144
jsonToBytes {
197145
withExecuteContext { ctx =>
@@ -220,15 +168,7 @@ abstract class Backend extends Closeable {
220168
}
221169
}
222170

223-
def loadReferencesFromDataset(path: String): Array[Byte] = {
224-
withExecuteContext { ctx =>
225-
val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path)
226-
rgs.foreach(addReference)
227-
228-
implicit val formats: Formats = defaultJSONFormats
229-
Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8)
230-
}
231-
}
171+
def loadReferencesFromDataset(path: String): Array[Byte]
232172

233173
def fromFASTAFile(
234174
name: String,
@@ -240,18 +180,22 @@ abstract class Backend extends Closeable {
240180
parInput: Array[String],
241181
): Array[Byte] =
242182
withExecuteContext { ctx =>
243-
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
244-
xContigs, yContigs, mtContigs, parInput)
245-
rg.toJSONString.getBytes(StandardCharsets.UTF_8)
183+
jsonToBytes {
184+
Extraction.decompose {
185+
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
186+
xContigs, yContigs, mtContigs, parInput).toJSON
187+
}(defaultJSONFormats)
188+
}
246189
}
247190

248-
def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes {
191+
def parseVCFMetadata(path: String): Array[Byte] =
249192
withExecuteContext { ctx =>
250-
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
251-
implicit val formats = defaultJSONFormats
252-
Extraction.decompose(metadata)
193+
jsonToBytes {
194+
Extraction.decompose {
195+
LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
196+
}(defaultJSONFormats)
197+
}
253198
}
254-
}
255199

256200
def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
257201
: Array[Byte] =
@@ -261,27 +205,6 @@ abstract class Backend extends Closeable {
261205
)
262206
}
263207

264-
def pyRegisterIR(
265-
name: String,
266-
typeParamStrs: java.util.ArrayList[String],
267-
argNameStrs: java.util.ArrayList[String],
268-
argTypeStrs: java.util.ArrayList[String],
269-
returnType: String,
270-
bodyStr: String,
271-
): Unit = {
272-
withExecuteContext { ctx =>
273-
IRFunctionRegistry.registerIR(
274-
ctx,
275-
name,
276-
typeParamStrs.asScala.toArray,
277-
argNameStrs.asScala.toArray,
278-
argTypeStrs.asScala.toArray,
279-
returnType,
280-
bodyStr,
281-
)
282-
}
283-
}
284-
285208
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
286209
}
287210

hail/src/main/scala/is/hail/backend/BackendServer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
112112
}
113113
return
114114
}
115+
115116
val response: Array[Byte] = exchange.getRequestURI.getPath match {
116117
case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir)
117118
case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir)

hail/src/main/scala/is/hail/backend/ExecuteContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class ExecuteContext(
128128
)
129129
}
130130

131-
val stateManager = HailStateManager(backend.references)
131+
def stateManager = HailStateManager(backend.references.toMap)
132132

133133
val tempFileManager: TempFileManager =
134134
if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs)

0 commit comments

Comments
 (0)