Skip to content

Commit

Permalink
Auto-Select via SAM (#7051)
Browse files Browse the repository at this point in the history
* WIP: Auto-Select via SAM

* load data from datastore

* send element class to sam server

* mag handling

* rough sam integration into frontend

* use embedding route from backend instead of fetching precomputed embedding

* small unrelated bug fix for canceling poll under certain race condition

* switch for dynamic vs hardcoded embedding

* better error handling and disable fetching of hardcoded embedding

* only use user-requested bbox when applying results for better performance

* request embedding for actual user position

* implement caching and reuse of embedding

* temporarily disable most CI checks

* restore application.conf setting for slick and swagger; add another comment to explain impact on dev instances

* make it work in arbitrary mags

* integrate mag into embedding cache

* show busy indicator when loading embedding

* remove the entire heuristic-based quickselect code

* refactor and clean up so that ML and heuristic based quick select are both usable

* allow to select AI or not for quick select in UI

* fix inverted style

* only allow quick-select with SAM on wkorg

* make sam select compatible with other viewports

* refactor embedding request

* fix linting

* Revert "temporarily disable most CI checks"

This reverts commit 3f28c99.

* prefetch embedding as soon as user presses mouse down

* use segmentAnythingEnabled instead of isWkorgInstance

* update snapshots

* also prefetch ORT session

* use camel case in infer code

* optimize extraction of mask

* rename some vars

* adapt analytics event

* update changelog

* update docs

* fix some deprecations

* don't cache failed embeddings; only support uint8 in ai mode

* re-add assertion, disable in conf by default

* update snapshot (segmentAnything disabled by default)

* don't clamp min coord of bounding boxes to 0

* remove unnecessary V3.max calls

* pr feedback

* catch error better if wasm cannot be loaded

* avoid redundant error toast

* further simplication for bbox

* add assertion for bbox < 1024**2

* slice cache when adding to it instead of when accessing it

* remove time measurement code

* align user bbox to mag before using it to ai-select

* more bbox alignment (except for geometry)

* fix onnxCoord interpretation for yz viewport

* fix mag-alignment logic by rewriting the alignWithMag function to provide more strategies

* fix usage of wrong cache entries because of zero-volume bounding boxes that led to positive containment checks

* fix that QuickSelectGeometry would be invisible when the third dimension's fractional was below 0.5

* extrude bounding boxes by correctly mag-adapted depth

* add comments to inference code

* prefetch session as soon as quick select tool is activated

* allow to cancel quick select with escape while drawing rectangle

* fix invisible geometry

* clean up console.logs

* fix linting

* include intensity min/max and element class in byte array sent to sam server

* assert min/max intensity is supplied if element class is float or double

* fix frontend typecheck

* send intensity range to embedding endpoint if layer is not uint8

* backend error messages

* use 4 bytes each for min/max of intensity range, not 8

* only show center marker when using old quick-select mode; fix scale of center marker

* add comment

* send metadata with correct endianness

* take layer name into account when caching embedding

---------

Co-authored-by: Philipp Otto <[email protected]>
Co-authored-by: Philipp Otto <[email protected]>
Co-authored-by: Norman Rzepka <[email protected]>
  • Loading branch information
4 people authored May 17, 2023
1 parent 53e6686 commit 155e94e
Show file tree
Hide file tree
Showing 46 changed files with 1,853 additions and 1,039 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
### Added
- Added segment groups so that segments can be organized in a hierarchy (similar to skeletons). [#6966](https://github.com/scalableminds/webknossos/pull/6966)
- In addition to drag and drop, the selected tree(s) in the Skeleton tab can also be moved into another group by right-clicking the target group and selecting "Move selected tree(s) here". [#7005](https://github.com/scalableminds/webknossos/pull/7005)
- Added a machine-learning based quick select mode. Activate it via the "AI" button in the toolbar after selecting the quick-select tool. [#7051](https://github.com/scalableminds/webknossos/pull/7051)
- Added support for remote datasets encoded with [brotli](https://datatracker.ietf.org/doc/html/rfc7932). [#7041](https://github.com/scalableminds/webknossos/pull/7041)
- Teams can be edited more straight-forwardly in a popup in the team edit page. [#7043](https://github.com/scalableminds/webknossos/pull/7043)
- Annotations with Editable Mappings (a.k.a Supervoxel Proofreading) can now be merged. [#7026](https://github.com/scalableminds/webknossos/pull/7026)
Expand Down
59 changes: 57 additions & 2 deletions app/controllers/DataSetController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ import com.scalableminds.util.accesscontext.{DBAccessContext, GlobalAccessContex
import com.scalableminds.util.geometry.{BoundingBox, Vec3Int}
import com.scalableminds.util.time.Instant
import com.scalableminds.util.tools.{Fox, JsonHelper, Math}
import com.scalableminds.webknossos.datastore.models.datasource.{DataLayer, DataLayerLike, GenericDataSource}
import com.scalableminds.webknossos.datastore.models.datasource.{
DataLayer,
DataLayerLike,
ElementClass,
GenericDataSource
}
import io.swagger.annotations._
import models.analytics.{AnalyticsService, ChangeDatasetSettingsEvent, OpenDatasetEvent}
import models.binary._
Expand All @@ -20,7 +25,7 @@ import play.api.i18n.{Messages, MessagesProvider}
import play.api.libs.functional.syntax._
import play.api.libs.json._
import play.api.mvc.{Action, AnyContent, PlayBodyParsers}
import utils.ObjectId
import utils.{ObjectId, WkConf}

import javax.inject.Inject
import scala.collection.mutable.ListBuffer
Expand All @@ -42,6 +47,15 @@ object DatasetUpdateParameters extends TristateOptionJsonHelper {
Json.configured(tristateOptionParsing).format[DatasetUpdateParameters]
}

case class SegmentAnythingEmbeddingParameters(
mag: Vec3Int,
boundingBox: BoundingBox
)

object SegmentAnythingEmbeddingParameters {
implicit val jsonFormat: Format[SegmentAnythingEmbeddingParameters] = Json.format[SegmentAnythingEmbeddingParameters]
}

@Api
class DataSetController @Inject()(userService: UserService,
userDAO: UserDAO,
Expand All @@ -50,8 +64,10 @@ class DataSetController @Inject()(userService: UserService,
dataSetLastUsedTimesDAO: DataSetLastUsedTimesDAO,
organizationDAO: OrganizationDAO,
teamDAO: TeamDAO,
wKRemoteSegmentAnythingClient: WKRemoteSegmentAnythingClient,
teamService: TeamService,
dataSetDAO: DataSetDAO,
conf: WkConf,
analyticsService: AnalyticsService,
mailchimpClient: MailchimpClient,
exploreRemoteLayerService: ExploreRemoteLayerService,
Expand Down Expand Up @@ -519,4 +535,43 @@ Expects:
case _ => Messages("dataSet.notFoundConsiderLogin", dataSetName)
}

@ApiOperation(hidden = true, value = "")
def segmentAnythingEmbedding(organizationName: String,
dataSetName: String,
dataLayerName: String,
intensityMin: Option[Float],
intensityMax: Option[Float]): Action[SegmentAnythingEmbeddingParameters] =
sil.SecuredAction.async(validateJson[SegmentAnythingEmbeddingParameters]) { implicit request =>
log() {
for {
_ <- bool2Fox(conf.Features.segmentAnythingEnabled) ?~> "segmentAnything.notEnabled"
_ <- bool2Fox(conf.SegmentAnything.uri.nonEmpty) ?~> "segmentAnything.noUri"
dataset <- dataSetDAO.findOneByNameAndOrganizationName(dataSetName, organizationName) ?~> notFoundMessage(
dataSetName) ~> NOT_FOUND
dataSource <- dataSetService.dataSourceFor(dataset) ?~> "dataSource.notFound" ~> NOT_FOUND
usableDataSource <- dataSource.toUsable ?~> "dataSet.notImported"
dataLayer <- usableDataSource.dataLayers.find(_.name == dataLayerName) ?~> "dataSet.noLayers"
datastoreClient <- dataSetService.clientFor(dataset)(GlobalAccessContext)
targetMagBbox: BoundingBox = request.body.boundingBox / request.body.mag
_ <- bool2Fox(targetMagBbox.dimensions.sorted == Vec3Int(1, 1024, 1024)) ?~> s"Target-mag bbox must be sized 1024×1024×1 (or transposed), got ${targetMagBbox.dimensions}"
data <- datastoreClient.getLayerData(organizationName,
dataset,
dataLayer.name,
request.body.boundingBox,
request.body.mag) ?~> "segmentAnything.getData.failed"
_ = logger.debug(
s"Sending ${data.length} bytes to SAM server, element class is ${dataLayer.elementClass}, range: $intensityMin-$intensityMax...")
_ <- bool2Fox(
!(dataLayer.elementClass == ElementClass.float || dataLayer.elementClass == ElementClass.double) || (intensityMin.isDefined && intensityMax.isDefined)) ?~> "For float and double data, a supplied intensity range is required."
embedding <- wKRemoteSegmentAnythingClient.getEmbedding(
data,
dataLayer.elementClass,
intensityMin,
intensityMax) ?~> "segmentAnything.getEmbedding.failed"
_ = logger.debug(
s"Received ${embedding.length} bytes of embedding from SAM server, forwarding to front-end...")
} yield Ok(embedding)
}
}

}
7 changes: 3 additions & 4 deletions app/models/binary/DataSetService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import com.scalableminds.webknossos.datastore.models.datasource.{
import com.scalableminds.webknossos.datastore.rpc.RPC
import com.scalableminds.webknossos.datastore.storage.TemporaryStore
import com.typesafe.scalalogging.LazyLogging
import models.folder.{FolderDAO, FolderService}
import models.folder.FolderDAO

import javax.inject.Inject
import models.job.WorkerDAO
Expand All @@ -41,7 +41,6 @@ class DataSetService @Inject()(organizationDAO: OrganizationDAO,
dataStoreService: DataStoreService,
teamService: TeamService,
userService: UserService,
folderService: FolderService,
val thumbnailCache: TemporaryStore[String, Array[Byte]],
rpc: RPC,
conf: WkConf)(implicit ec: ExecutionContext)
Expand All @@ -66,7 +65,7 @@ class DataSetService @Inject()(organizationDAO: OrganizationDAO,
createDataSet(dataStore, organizationName, unreportedDatasource)
}

def createDataSet(
private def createDataSet(
dataStore: DataStore,
owningOrganization: String,
dataSource: InboxDataSource,
Expand Down Expand Up @@ -256,7 +255,7 @@ class DataSetService @Inject()(organizationDAO: OrganizationDAO,
dataStore <- dataStoreFor(dataSet)
} yield new WKRemoteDataStoreClient(dataStore, rpc)

def lastUsedTimeFor(_dataSet: ObjectId, userOpt: Option[User]): Fox[Instant] =
private def lastUsedTimeFor(_dataSet: ObjectId, userOpt: Option[User]): Fox[Instant] =
userOpt match {
case Some(user) =>
(for {
Expand Down
22 changes: 21 additions & 1 deletion app/models/binary/WKRemoteDataStoreClient.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package models.binary

import com.scalableminds.util.geometry.Vec3Int
import com.scalableminds.util.geometry.{BoundingBox, Vec3Int}
import com.scalableminds.util.tools.Fox
import com.scalableminds.webknossos.datastore.rpc.RPC
import com.scalableminds.webknossos.datastore.services.DirectoryStorageReport
Expand Down Expand Up @@ -29,6 +29,26 @@ class WKRemoteDataStoreClient(dataStore: DataStore, rpc: RPC) extends LazyLoggin
.getWithBytesResponse
}

def getLayerData(organizationName: String,
dataset: DataSet,
layerName: String,
mag1BoundingBox: BoundingBox,
mag: Vec3Int): Fox[Array[Byte]] = {
val targetMagBoundingBox = mag1BoundingBox / mag
logger.debug(s"Fetching raw data. Mag $mag, mag1 bbox: $mag1BoundingBox, target-mag bbox: $targetMagBoundingBox")
rpc(
s"${dataStore.url}/data/datasets/${urlEncode(organizationName)}/${dataset.urlEncodedName}/layers/$layerName/data")
.addQueryString("token" -> RpcTokenHolder.webKnossosToken)
.addQueryString("mag" -> mag.toMagLiteral())
.addQueryString("x" -> mag1BoundingBox.topLeft.x.toString)
.addQueryString("y" -> mag1BoundingBox.topLeft.y.toString)
.addQueryString("z" -> mag1BoundingBox.topLeft.z.toString)
.addQueryString("width" -> targetMagBoundingBox.width.toString)
.addQueryString("height" -> targetMagBoundingBox.height.toString)
.addQueryString("depth" -> targetMagBoundingBox.depth.toString)
.getWithBytesResponse
}

def findPositionWithData(organizationName: String, dataSet: DataSet, dataLayerName: String): Fox[JsObject] =
rpc(
s"${dataStore.url}/data/datasets/${urlEncode(organizationName)}/${dataSet.urlEncodedName}/layers/$dataLayerName/findData")
Expand Down
29 changes: 29 additions & 0 deletions app/models/binary/WKRemoteSegmentAnythingClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package models.binary

import com.scalableminds.util.tools.Fox
import com.scalableminds.webknossos.datastore.rpc.RPC
import com.scalableminds.webknossos.datastore.models.datasource.ElementClass
import utils.WkConf

import java.nio.{ByteBuffer, ByteOrder}
import javax.inject.Inject

class WKRemoteSegmentAnythingClient @Inject()(rpc: RPC, conf: WkConf) {
def getEmbedding(imageData: Array[Byte],
elementClass: ElementClass.Value,
intensityMin: Option[Float],
intensityMax: Option[Float]): Fox[Array[Byte]] = {
val metadataLengthInBytes = 1 + 1 + 4 + 4
val buffer = ByteBuffer.allocate(metadataLengthInBytes + imageData.length)
buffer.put(ElementClass.encodeAsByte(elementClass))
buffer.put(if (intensityMin.isDefined && intensityMax.isDefined) 1.toByte else 0.toByte)
buffer.order(ByteOrder.LITTLE_ENDIAN).putFloat(intensityMin.getOrElse(0.0f))
buffer.order(ByteOrder.LITTLE_ENDIAN).putFloat(intensityMax.getOrElse(0.0f))
val imageWithMetadata = buffer.array()
System.arraycopy(imageData, 0, imageWithMetadata, metadataLengthInBytes, imageData.length)
rpc(s"${conf.SegmentAnything.uri}/predictions/sam_vit_l")
.addQueryString("elementClass" -> elementClass.toString)
.postBytesWithBytesResponse(imageWithMetadata)
}

}
8 changes: 7 additions & 1 deletion app/utils/WkConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class WkConf @Inject()(configuration: Configuration) extends ConfigReader with L
val exportTiffMaxVolumeMVx: Long = get[Long]("features.exportTiffMaxVolumeMVx")
val exportTiffMaxEdgeLengthVx: Long = get[Long]("features.exportTiffMaxEdgeLengthVx")
val openIdConnectEnabled: Boolean = get[Boolean]("features.openIdConnectEnabled")
val segmentAnythingEnabled: Boolean = get[Boolean]("features.segmentAnythingEnabled")
}

object Datastore {
Expand Down Expand Up @@ -230,6 +231,10 @@ class WkConf @Inject()(configuration: Configuration) extends ConfigReader with L
val children = List(Loki)
}

object SegmentAnything {
val uri: String = get[String]("segmentAnything.uri")
}

val children =
List(
Http,
Expand All @@ -246,7 +251,8 @@ class WkConf @Inject()(configuration: Configuration) extends ConfigReader with L
GoogleAnalytics,
BackendAnalytics,
Slick,
Voxelytics
Voxelytics,
SegmentAnything
)

}
8 changes: 7 additions & 1 deletion conf/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ features {
# The Only valid item value is currently "ConnectomeView":
optInTabs = []
openIdConnectEnabled = false
segmentAnythingEnabled = false
}

# Serve annotations. Only active if the corresponding play module is enabled
Expand Down Expand Up @@ -289,10 +290,15 @@ voxelytics {
}
}

segmentAnything {
uri = "http://localhost:8080"
}

# Avoid creation of a pid file
pidfile.path = "/dev/null"


# # uncomment these lines for faster restart during local backend development (but beware the then-missing features):
# # Uncomment these lines for faster restart during local backend development (but beware the then-missing features):
# # Uncommenting these lines also means that a DB has to be set up manually for a dev deployment.
# slick.checkSchemaOnStartup = false
# play.modules.disabled += "play.modules.swagger.SwaggerModule"
5 changes: 5 additions & 0 deletions conf/messages
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,8 @@ folder.update.name.failed=Failed to update the folder’s name
folder.update.teams.failed=Failed to update the folder’s allowed teams
folder.create.failed.teams.failed=Failed to create folder in this location
folder.noWriteAccess=No write access in this folder

segmentAnything.notEnabled=AI based quick select is not enabled for this WEBKNOSSOS instance.
segmentAnything.noUri=No Uri for SAM server configured.
segmentAnything.getData.failed=Failed to get image data to send to SAM server.
segmentAnything.getEmbedding.failed=Failed to get image embedding from SAM server.
Loading

0 comments on commit 155e94e

Please sign in to comment.