Skip to content

Commit 8097484

Browse files
zachgkdrivanov
authored andcommitted
Faster Scala NDArray to BufferedImage function (apache#16219)
1 parent 1cf7ed8 commit 8097484

File tree

1 file changed

+6
-4
lines changed
  • scala-package/core/src/main/scala/org/apache/mxnet

1 file changed

+6
-4
lines changed

scala-package/core/src/main/scala/org/apache/mxnet/Image.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,18 @@ object Image {
174174
def toImage(src: NDArray): BufferedImage = {
175175
require(src.dtype == DType.UInt8, "The input NDArray must be bytes")
176176
require(src.shape.length == 3, "The input should contains height, width and channel")
177+
require(src.shape(2) == 3, "There should be three channels: RGB")
177178
val height = src.shape.get(0)
178179
val width = src.shape.get(1)
179180
val img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
181+
val arr = src.toArray
180182
(0 until height).par.foreach(r => {
181183
(0 until width).par.foreach(c => {
182-
val arr = src.at(r).at(c).toArray
183184
// NDArray in RGB
184-
val red = arr(0).toByte & 0xFF
185-
val green = arr(1).toByte & 0xFF
186-
val blue = arr(2).toByte & 0xFF
185+
val cellIndex = r * width * 3 + c * 3
186+
val red = arr(cellIndex).toByte & 0xFF
187+
val green = arr(cellIndex + 1).toByte & 0xFF
188+
val blue = arr(cellIndex + 2).toByte & 0xFF
187189
val rgb = (red << 16) | (green << 8) | blue
188190
img.setRGB(c, r, rgb)
189191
})

0 commit comments

Comments
 (0)