Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid allocating spire uint objects during apply agglomerate #6532

Merged
merged 7 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
### Added
- Support for a new mesh file format which allows up to billions of meshes. [#6491](https://github.com/scalableminds/webknossos/pull/6491)
- Remote n5 datasets can now also be explored and added. [#6520](https://github.com/scalableminds/webknossos/pull/6520)
- Improved performance for applying agglomerate mappings on segmentation data. [#6532](https://github.com/scalableminds/webknossos/pull/6532)

### Changed
- Creating tasks in bulk now also supports referencing task types by their summary instead of id. [#6486](https://github.com/scalableminds/webknossos/pull/6486)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
package com.scalableminds.webknossos.datastore.services

import java.nio._
import java.nio.file.{Files, Paths}

import ch.systemsx.cisd.hdf5._
import com.scalableminds.util.io.PathUtils
import com.scalableminds.webknossos.datastore.DataStoreConfig
import com.scalableminds.webknossos.datastore.EditableMapping.{AgglomerateEdge, AgglomerateGraph}
import com.scalableminds.webknossos.datastore.SkeletonTracing.{Edge, SkeletonTracing, Tree}
import com.scalableminds.webknossos.datastore.geometry.Vec3IntProto
import com.scalableminds.webknossos.datastore.helpers.{NodeDefaults, SkeletonTracingDefaults}
import com.scalableminds.webknossos.datastore.models.datasource.ElementClass
import com.scalableminds.webknossos.datastore.models.requests.DataServiceDataRequest
import com.scalableminds.webknossos.datastore.storage._
import com.typesafe.scalalogging.LazyLogging
import javax.inject.Inject
import net.liftweb.common.Box.tryo
import net.liftweb.common.{Box, Failure, Full}
import net.liftweb.util.Helpers.tryo
import org.apache.commons.io.FilenameUtils
import spire.math.{UByte, UInt, ULong, UShort}

import java.nio._
import java.nio.file.{Files, Paths}
import javax.inject.Inject

class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverter with LazyLogging {
private val agglomerateDir = "agglomerates"
Expand All @@ -41,15 +41,11 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
}

def applyAgglomerate(request: DataServiceDataRequest)(data: Array[Byte]): Array[Byte] = {
def byteFunc(buf: ByteBuffer, lon: Long) = buf put lon.toByte
def shortFunc(buf: ByteBuffer, lon: Long) = buf putShort lon.toShort
def intFunc(buf: ByteBuffer, lon: Long) = buf putInt lon.toInt
def longFunc(buf: ByteBuffer, lon: Long) = buf putLong lon

val agglomerateFileKey = AgglomerateFileKey.fromDataRequest(request)

def convertToAgglomerate(input: Array[ULong],
numBytes: Int,
def convertToAgglomerate(input: Array[Long],
bytesPerElement: Int,
bufferFunc: (ByteBuffer, Long) => ByteBuffer): Array[Byte] = {

val cachedAgglomerateFile = agglomerateFileCache.withCache(agglomerateFileKey)(initHDFReader)
Expand All @@ -64,16 +60,33 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
cachedAgglomerateFile.finishAccess()

agglomerateIds
.foldLeft(ByteBuffer.allocate(numBytes * input.length).order(ByteOrder.LITTLE_ENDIAN))(bufferFunc)
.foldLeft(ByteBuffer.allocate(bytesPerElement * input.length).order(ByteOrder.LITTLE_ENDIAN))(bufferFunc)
.array
}

val bytesPerElement = ElementClass.bytesPerElement(request.dataLayer.elementClass)
/* Every value of the segmentation data needs to be converted to Long to then look up the
agglomerate id in the segment-to-agglomerate array.
The value is first converted to the primitive signed number types, and then converted
to Long via uByteToLong, uShortToLong etc, which perform bitwise and to take care of
the unsigned semantics. Using functions avoids allocating intermediate UnsignedInteger objects.
Allocating a fixed-length LongBuffer first is a further performance optimization.
*/
convertData(data, request.dataLayer.elementClass) match {
case data: Array[UByte] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 1, byteFunc)
case data: Array[UShort] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 2, shortFunc)
case data: Array[UInt] => convertToAgglomerate(data.map(e => ULong(e.toLong)), 4, intFunc)
case data: Array[ULong] => convertToAgglomerate(data, 8, longFunc)
case _ => data
case data: Array[Byte] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uByteToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putByte)
case data: Array[Short] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uShortToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putShort)
case data: Array[Int] =>
val longBuffer = LongBuffer.allocate(data.length)
data.foreach(e => longBuffer.put(uIntToLong(e)))
convertToAgglomerate(longBuffer.array, bytesPerElement, putInt)
case data: Array[Long] => convertToAgglomerate(data, bytesPerElement, putLong)
fm3 marked this conversation as resolved.
Show resolved Hide resolved
case _ => data
}
}

Expand Down Expand Up @@ -109,7 +122,7 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte

val defaultCache: Either[AgglomerateIdCache, BoundingBoxCache] =
if (Files.exists(cumsumPath)) {
Right(CumsumParser.parse(cumsumPath.toFile, ULong(config.Datastore.Cache.AgglomerateFile.cumsumMaxReaderRange)))
Right(CumsumParser.parse(cumsumPath.toFile, config.Datastore.Cache.AgglomerateFile.cumsumMaxReaderRange))
} else {
Left(agglomerateIdCache)
}
Expand Down Expand Up @@ -216,8 +229,8 @@ class AgglomerateService @Inject()(config: DataStoreConfig) extends DataConverte
val cachedAgglomerateFile = agglomerateFileCache.withCache(agglomerateFileKey)(initHDFReader)

tryo {
val agglomerateIds = segmentIds.map { segmentId =>
cachedAgglomerateFile.agglomerateIdCache.withCache(ULong(segmentId),
val agglomerateIds = segmentIds.map { segmentId: Long =>
cachedAgglomerateFile.agglomerateIdCache.withCache(segmentId,
cachedAgglomerateFile.reader,
cachedAgglomerateFile.dataset)(readHDF)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,40 @@ package com.scalableminds.webknossos.datastore.services

import com.scalableminds.util.tools.FoxImplicits
import com.scalableminds.webknossos.datastore.models.datasource.ElementClass
import spire.math._
import spire.math.{ULong, _}

import java.nio._
import scala.reflect.ClassTag

trait DataConverter extends FoxImplicits {

def putByte(buf: ByteBuffer, lon: Long): ByteBuffer = buf put lon.toByte
def putShort(buf: ByteBuffer, lon: Long): ByteBuffer = buf putShort lon.toShort
def putInt(buf: ByteBuffer, lon: Long): ByteBuffer = buf putInt lon.toInt
def putLong(buf: ByteBuffer, lon: Long): ByteBuffer = buf putLong lon

def uByteToLong(uByte: Byte): Long = uByte & 0xffL
def uShortToLong(uShort: Short): Long = uShort & 0xffffL
def uIntToLong(uInt: Int): Long = uInt & 0xffffffffL

def convertData(data: Array[Byte],
elementClass: ElementClass.Value,
filterZeroes: Boolean = false): Array[_ >: UByte with UShort with UInt with ULong with Float] =
elementClass: ElementClass.Value): Array[_ >: Byte with Short with Int with Long with Float] =
elementClass match {
case ElementClass.uint8 =>
case ElementClass.uint8 | ElementClass.int8 =>
convertDataImpl[Byte, ByteBuffer](data, DataTypeFunctors[Byte, ByteBuffer](identity, _.get(_), _.toByte))
.map(UByte(_))
.filter(!filterZeroes || _ != UByte(0))
case ElementClass.uint16 =>
case ElementClass.uint16 | ElementClass.int16 =>
convertDataImpl[Short, ShortBuffer](data,
DataTypeFunctors[Short, ShortBuffer](_.asShortBuffer, _.get(_), _.toShort))
.map(UShort(_))
.filter(!filterZeroes || _ != UShort(0))
case ElementClass.uint24 =>
convertDataImpl[Byte, ByteBuffer](data, DataTypeFunctors[Byte, ByteBuffer](identity, _.get(_), _.toByte))
.map(UByte(_))
.filter(!filterZeroes || _ != UByte(0))
case ElementClass.uint32 =>
case ElementClass.uint32 | ElementClass.int32 =>
convertDataImpl[Int, IntBuffer](data, DataTypeFunctors[Int, IntBuffer](_.asIntBuffer, _.get(_), _.toInt))
.map(UInt(_))
.filter(!filterZeroes || _ != UInt(0))
case ElementClass.uint64 =>
case ElementClass.uint64 | ElementClass.int64 =>
convertDataImpl[Long, LongBuffer](data, DataTypeFunctors[Long, LongBuffer](_.asLongBuffer, _.get(_), identity))
.map(ULong(_))
.filter(!filterZeroes || _ != ULong(0))
case ElementClass.float =>
convertDataImpl[Float, FloatBuffer](data,
DataTypeFunctors[Float, FloatBuffer](
_.asFloatBuffer(),
_.get(_),
_.toFloat)).filter(!_.isNaN).filter(!filterZeroes || _ != 0f)
convertDataImpl[Float, FloatBuffer](
data,
DataTypeFunctors[Float, FloatBuffer](_.asFloatBuffer(), _.get(_), _.toFloat))
}

private def convertDataImpl[T: ClassTag, B <: Buffer](data: Array[Byte],
Expand All @@ -50,4 +46,59 @@ trait DataConverter extends FoxImplicits {
dataTypeFunctor.copyDataFn(srcBuffer, dstArray)
dstArray
}

def toUnsigned(data: Array[_ >: Byte with Short with Int with Long with Float])
: Array[_ >: UByte with UShort with UInt with ULong with Float] =
data match {
case d: Array[Byte] => d.map(UByte(_))
case d: Array[Short] => d.map(UShort(_))
case d: Array[Int] => d.map(UInt(_))
case d: Array[Long] => d.map(ULong(_))
case d: Array[Float] => d
}

def filterZeroes(data: Array[_ >: Byte with Short with Int with Long with Float],
skip: Boolean = false): Array[_ >: Byte with Short with Int with Long with Float] =
if (skip) data
else {
val zeroByte = 0.toByte
val zeroShort = 0.toShort
val zeroInt = 0
val zeroLong = 0L
data match {
case d: Array[Byte] => d.filter(_ != zeroByte)
case d: Array[Short] => d.filter(_ != zeroShort)
case d: Array[Int] => d.filter(_ != zeroInt)
case d: Array[Long] => d.filter(_ != zeroLong)
case d: Array[Float] => d.filter(!_.isNaN).filter(_ != 0f)
}
}

def toBytesSpire(typed: Array[_ >: UByte with UShort with UInt with ULong with Float],
elementClass: ElementClass.Value): Array[Byte] = {
val numBytes = ElementClass.bytesPerElement(elementClass)
val byteBuffer = ByteBuffer.allocate(numBytes * typed.length).order(ByteOrder.LITTLE_ENDIAN)
typed match {
case data: Array[UByte] => data.foreach(el => byteBuffer.put(el.signed))
case data: Array[UShort] => data.foreach(el => byteBuffer.putChar(el.signed))
case data: Array[UInt] => data.foreach(el => byteBuffer.putInt(el.signed))
case data: Array[ULong] => data.foreach(el => byteBuffer.putLong(el.signed))
case data: Array[Float] => data.foreach(el => byteBuffer.putFloat(el))
}
byteBuffer.array()
}

def toBytes(typed: Array[_ >: Byte with Short with Int with Long with Float],
elementClass: ElementClass.Value): Array[Byte] = {
val numBytes = ElementClass.bytesPerElement(elementClass)
val byteBuffer = ByteBuffer.allocate(numBytes * typed.length).order(ByteOrder.LITTLE_ENDIAN)
typed match {
case data: Array[Byte] => data.foreach(el => byteBuffer.put(el))
case data: Array[Short] => data.foreach(el => byteBuffer.putShort(el))
case data: Array[Int] => data.foreach(el => byteBuffer.putInt(el))
case data: Array[Long] => data.foreach(el => byteBuffer.putLong(el))
case data: Array[Float] => data.foreach(el => byteBuffer.putFloat(el))
}
byteBuffer.array()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package com.scalableminds.webknossos.datastore.services

import com.google.inject.Inject
import com.scalableminds.util.geometry.Vec3Int
import com.scalableminds.util.tools.{Fox, FoxImplicits, Math}
import com.scalableminds.util.tools.Math
import com.scalableminds.util.tools.{Fox, FoxImplicits}
import com.scalableminds.webknossos.datastore.models.datasource.{DataLayer, DataSource, ElementClass}
import com.scalableminds.webknossos.datastore.models.requests.DataServiceDataRequest
import com.scalableminds.webknossos.datastore.models.{DataRequest, VoxelPosition}
import net.liftweb.common.Full
import play.api.libs.json.{Json, OFormat}
import spire.math._
import spire.math.{UByte, UInt, ULong, UShort}

import scala.annotation.tailrec
import scala.concurrent.ExecutionContext
Expand Down Expand Up @@ -137,14 +138,15 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
} yield positionAndResolutionOpt

def meanAndStdDev(dataSource: DataSource, dataLayer: DataLayer): Fox[(Double, Double)] = {
Fox.successful(5.0, 5.0)

def convertNonZeroDataToDouble(data: Array[Byte], elementClass: ElementClass.Value): Array[Double] =
convertData(data, elementClass, filterZeroes = true) match {
case d: Array[UByte] => d.map(_.toDouble)
case d: Array[UShort] => d.map(_.toDouble)
case d: Array[UInt] => d.map(_.toDouble)
case d: Array[ULong] => d.map(_.toDouble)
case d: Array[Float] => d.map(_.toDouble)
filterZeroes(convertData(data, elementClass)) match {
case d: Array[Byte] => d.map(uByteToLong).map(_.toDouble)
case d: Array[Short] => d.map(uShortToLong).map(_.toDouble)
case d: Array[Int] => d.map(uIntToLong).map(_.toDouble)
case d: Array[Long] => d.map(_.toDouble)
case d: Array[Float] => d.map(_.toDouble)
}

def meanAndStdDevForPositions(positions: List[Vec3Int], resolution: Vec3Int): Fox[(Double, Double)] =
Expand Down Expand Up @@ -199,7 +201,9 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
}
if (isUint24) {
val listOfCounts = counts.grouped(256).toList
listOfCounts.map(counts => { counts(0) = 0; Histogram(counts, counts.sum.toInt, extrema._1, extrema._2) })
listOfCounts.map(counts => {
counts(0) = 0; Histogram(counts, counts.sum.toInt, extrema._1, extrema._2)
})
} else
List(Histogram(counts, data.length, extrema._1, extrema._2))
}
Expand All @@ -208,7 +212,7 @@ class FindDataService @Inject()(dataServicesHolder: BinaryDataServiceHolder)(imp
for {
dataConcatenated <- getConcatenatedDataFor(dataSource, dataLayer, positions, resolution) ?~> "dataSet.noData"
isUint24 = dataLayer.elementClass == ElementClass.uint24
convertedData = convertData(dataConcatenated, dataLayer.elementClass, filterZeroes = !isUint24)
convertedData = toUnsigned(filterZeroes(convertData(dataConcatenated, dataLayer.elementClass), skip = isUint24))
} yield calculateHistogramValues(convertedData, dataLayer.bytesPerElement, isUint24)

if (dataLayer.resolutions.nonEmpty)
Expand Down
Loading