Skip to content

Commit

Permalink
Extract text with AI command (#2191)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-zaitsev authored Dec 16, 2024
1 parent 073d218 commit 4848e9f
Show file tree
Hide file tree
Showing 8 changed files with 547 additions and 94 deletions.
73 changes: 66 additions & 7 deletions maestro-ai/src/main/java/maestro/ai/Prediction.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,36 @@ data class Defect(
)

@Serializable
private data class ModelResponse(
private data class AskForDefectsResponse(
val defects: List<Defect>,
)

@Serializable
private data class ExtractTextResponse(
val text: String?
)

object Prediction {

private val askForDefectsSchema by lazy {
readSchema("askForDefects")
}

private val extractTextSchema by lazy {
readSchema("extractText")
}

/**
* We use JSON mode/Structured Outputs to define the schema of the response we expect from the LLM.
* - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
* - Gemini: https://ai.google.dev/gemini-api/docs/json-mode
*/
private val askForDefectsSchema: String = run {
val resourceStream = this::class.java.getResourceAsStream("/askForDefects_schema.json")
?: throw IllegalStateException("Could not find askForDefects_schema.json in resources")
private fun readSchema(name: String): String {
val fileName = "/${name}_schema.json"
val resourceStream = this::class.java.getResourceAsStream(fileName)
?: throw IllegalStateException("Could not find $fileName in resources")

resourceStream.bufferedReader().use { it.readText() }
return resourceStream.bufferedReader().use { it.readText() }
}

private val json = Json { ignoreUnknownKeys = true }
Expand Down Expand Up @@ -126,7 +140,7 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val defects = json.decodeFromString<ModelResponse>(aiResponse.response)
val defects = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return defects.defects
}

Expand Down Expand Up @@ -208,7 +222,52 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val response = json.decodeFromString<ModelResponse>(aiResponse.response)
val response = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return response.defects.firstOrNull()
}

suspend fun extractText(
aiClient: AI,
screen: ByteArray,
query: String,
): String {
val prompt = buildString {
append("What text on the screen matches the following query: $query")

append(
"""
|
|RULES:
|* Provide response as a valid JSON, with structure described below.
""".trimMargin("|")
)

append(
"""
|
|* You must provide result as a valid JSON object, matching this structure:
|
| {
| "text": <string>
| }
|
|DO NOT output any other information in the JSON object.
""".trimMargin("|")
)
}

val aiResponse = aiClient.chatCompletion(
prompt,
model = aiClient.defaultModel,
maxTokens = 4096,
identifier = "perform-assertion",
imageDetail = "high",
images = listOf(screen),
jsonSchema = if (aiClient is OpenAI) json.parseToJsonElement(extractTextSchema).jsonObject else null,
)

val response = json.decodeFromString<ExtractTextResponse>(aiResponse.response)
return response.text ?: ""
}

}
17 changes: 17 additions & 0 deletions maestro-ai/src/main/resources/extractText_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"name": "extractText",
"description": "Extracts text from an image based on a given query",
"strict": true,
"schema": {
"type": "object",
"required": [
"text"
],
"additionalProperties": false,
"properties": {
"text": {
"type": "string"
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ sealed interface Command {
fun visible(): Boolean = true

val label: String?

val optional: Boolean
}

Expand Down Expand Up @@ -65,18 +65,23 @@ data class SwipeCommand(
label != null -> {
label
}

elementSelector != null && direction != null -> {
"Swiping in $direction direction on ${elementSelector.description()}"
}

direction != null -> {
"Swiping in $direction direction in $duration ms"
}

startPoint != null && endPoint != null -> {
"Swipe from (${startPoint.x},${startPoint.y}) to (${endPoint.x},${endPoint.y}) in $duration ms"
}

startRelative != null && endRelative != null -> {
"Swipe from ($startRelative) to ($endRelative) in $duration ms"
}

else -> "Invalid input to swipe command"
}
}
Expand Down Expand Up @@ -113,11 +118,15 @@ data class ScrollUntilVisibleCommand(

private fun String.speedToDuration(): String {
val duration = ((1000 * (100 - this.toLong()).toDouble() / 100).toLong() + 1)
return if (duration < 0) { DEFAULT_SCROLL_DURATION } else duration.toString()
return if (duration < 0) {
DEFAULT_SCROLL_DURATION
} else duration.toString()
}

private fun String.timeoutToMillis(): String {
return if (this.toLong() < 0) { DEFAULT_TIMEOUT_IN_MILLIS } else this
return if (this.toLong() < 0) {
DEFAULT_TIMEOUT_IN_MILLIS
} else this
}

override fun description(): String {
Expand Down Expand Up @@ -336,7 +345,7 @@ data class AssertCommand(
) : Command {

override fun description(): String {
if (label != null){
if (label != null) {
return label
}
val timeoutStr = timeout?.let { " within $timeout ms" } ?: ""
Expand Down Expand Up @@ -382,7 +391,8 @@ data class AssertConditionCommand(
}

override fun description(): String {
val optional = if (optional || condition.visible?.optional == true || condition.notVisible?.optional == true ) "(Optional) " else ""
val optional =
if (optional || condition.visible?.optional == true || condition.notVisible?.optional == true) "(Optional) " else ""
return label ?: "Assert that $optional${condition.description()}"
}

Expand Down Expand Up @@ -425,6 +435,25 @@ data class AssertWithAICommand(
}
}

data class ExtractTextWithAICommand(
val query: String,
val outputVariable: String,
override val optional: Boolean = true,
override val label: String? = null
) : Command {
override fun description(): String {
if (label != null) return label

return "Extract text with AI: $query"
}

override fun evaluateScripts(jsEngine: JsEngine): Command {
return copy(
query = query.evaluateScripts(jsEngine),
)
}
}

data class InputTextCommand(
val text: String,
override val label: String? = null,
Expand Down Expand Up @@ -454,7 +483,7 @@ data class LaunchAppCommand(
) : Command {

override fun description(): String {
if (label != null){
if (label != null) {
return label
}

Expand Down Expand Up @@ -782,12 +811,15 @@ data class RepeatCommand(
label != null -> {
label
}

condition != null && timesInt > 1 -> {
"Repeat while ${condition.description()} (up to $timesInt times)"
}

condition != null -> {
"Repeat while ${condition.description()}"
}

timesInt > 1 -> "Repeat $timesInt times"
else -> "Repeat indefinitely"
}
Expand Down Expand Up @@ -824,6 +856,7 @@ data class RetryCommand(
label != null -> {
label
}

else -> "Retry $maxAttempts times"
}
}
Expand Down Expand Up @@ -943,8 +976,8 @@ data class TravelCommand(
val dLon = Math.toRadians(aLon - oLon)

val a = Math.sin(dLat / 2) * Math.sin(dLat / 2) +
Math.cos(Math.toRadians(oLat)) * Math.cos(Math.toRadians(aLat)) *
Math.sin(dLon / 2) * Math.sin(dLon / 2)
Math.cos(Math.toRadians(oLat)) * Math.cos(Math.toRadians(aLat)) *
Math.sin(dLon / 2) * Math.sin(dLon / 2)

val c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
val distance = earthRadius * c * 1000 // convert to meters
Expand All @@ -960,7 +993,12 @@ data class TravelCommand(

override fun evaluateScripts(jsEngine: JsEngine): Command {
return copy(
points = points.map { it.copy(latitude = it.latitude.evaluateScripts(jsEngine), longitude = it.longitude.evaluateScripts(jsEngine)) }
points = points.map {
it.copy(
latitude = it.latitude.evaluateScripts(jsEngine),
longitude = it.longitude.evaluateScripts(jsEngine)
)
}
)
}

Expand All @@ -987,7 +1025,7 @@ data class AddMediaCommand(
val mediaPaths: List<String>,
override val label: String? = null,
override val optional: Boolean = false,
): Command {
) : Command {

override fun description(): String {
return label ?: "Adding media files(${mediaPaths.size}) to the device"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data class MaestroCommand(
val assertConditionCommand: AssertConditionCommand? = null,
val assertNoDefectsWithAICommand: AssertNoDefectsWithAICommand? = null,
val assertWithAICommand: AssertWithAICommand? = null,
val extractTextWithAICommand: ExtractTextWithAICommand? = null,
val inputTextCommand: InputTextCommand? = null,
val inputRandomTextCommand: InputRandomCommand? = null,
val launchAppCommand: LaunchAppCommand? = null,
Expand Down Expand Up @@ -82,6 +83,7 @@ data class MaestroCommand(
assertConditionCommand = command as? AssertConditionCommand,
assertNoDefectsWithAICommand = command as? AssertNoDefectsWithAICommand,
assertWithAICommand = command as? AssertWithAICommand,
extractTextWithAICommand = command as? ExtractTextWithAICommand,
inputTextCommand = command as? InputTextCommand,
inputRandomTextCommand = command as? InputRandomCommand,
launchAppCommand = command as? LaunchAppCommand,
Expand Down Expand Up @@ -125,6 +127,7 @@ data class MaestroCommand(
assertConditionCommand != null -> assertConditionCommand
assertNoDefectsWithAICommand != null -> assertNoDefectsWithAICommand
assertWithAICommand != null -> assertWithAICommand
extractTextWithAICommand != null -> extractTextWithAICommand
inputTextCommand != null -> inputTextCommand
inputRandomTextCommand != null -> inputRandomTextCommand
launchAppCommand != null -> launchAppCommand
Expand Down
Loading

0 comments on commit 4848e9f

Please sign in to comment.