Skip to content

Commit

Permalink
Add NNAPI execution provider (#420) (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
ermolenkodev authored Aug 31, 2022
1 parent 2570608 commit e9248c5
Showing 1 changed file with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders

import ai.onnxruntime.OrtProvider
import ai.onnxruntime.OrtSession
import ai.onnxruntime.providers.NNAPIFlags
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CUDA
import java.util.EnumSet

/**
* These are classes representing the supported ONNXRuntime execution providers for KotlinDL.
* The supported providers are:
* - [CPU] (default)
* - [CUDA] (could be used if the CUDA runtime is installed)
* - [NNAPI] (could be used on Android if the NNAPI runtime is supported)
*
* Internally, the [OrtProvider] enum is used to indicate the provider.
*/
public sealed class ExecutionProvider(public val internalProviderId: OrtProvider) {
/**
* Default CPU execution provider.
* Available on all platforms.
*
* @param useBFCArenaAllocator If true, the CPU provider will use BFC arena allocator.
* @see [OrtProvider.CPU]
Expand All @@ -27,6 +32,7 @@ public sealed class ExecutionProvider(public val internalProviderId: OrtProvider

/**
* CUDA execution provider.
* Available only on platforms with Nvidia gpu and CUDA runtime installed.
*
* @param deviceId The device ID to use.
* @see [OrtProvider.CUDA]
Expand All @@ -37,6 +43,22 @@ public sealed class ExecutionProvider(public val internalProviderId: OrtProvider
}
}

/**
* NNAPI execution provider.
* Available only on Android.
*
* @param flags An NNAPI flags to modify the behavior of the NNAPI execution provider.
* @see [OrtProvider.NNAPI]
* @see <a href=https://onnxruntime.ai/docs/execution-providers/NNAPI-ExecutionProvider.html>NNAPI documentation</a>.
*/
public data class NNAPI(public val flags: Set<NNAPIFlags> = emptySet()) : ExecutionProvider(OrtProvider.NNAPI) {
override fun addOptionsTo(sessionOptions: OrtSession.SessionOptions) {
val internalNNAPIFlags = EnumSet.noneOf(NNAPIFlags::class.java)
flags.let { internalNNAPIFlags.addAll(it) }
sessionOptions.addNnapi(internalNNAPIFlags)
}
}

/**
* Adds execution provider options to the [OrtSession.SessionOptions].
*/
Expand Down

0 comments on commit e9248c5

Please sign in to comment.