Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract text with AI command #2191

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions maestro-ai/src/main/java/maestro/ai/Prediction.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,36 @@ data class Defect(
)

@Serializable
private data class ModelResponse(
private data class AskForDefectsResponse(
val defects: List<Defect>,
)

/**
 * Shape of the JSON object decoded from the model's reply to an extract-text request.
 *
 * [text] is nullable; `extractText` in this file returns an empty string when it is null.
 */
@Serializable
private data class ExtractTextResponse(
val text: String?
)

object Prediction {

/** Contents of the "askForDefects" schema resource, read via [readSchema] on first access. */
private val askForDefectsSchema: String by lazy { readSchema("askForDefects") }

/** Contents of the "extractText" schema resource, read via [readSchema] on first access. */
private val extractTextSchema: String by lazy { readSchema("extractText") }

/**
 * Reads a bundled JSON schema from the classpath and returns it as text.
 *
 * We use JSON mode/Structured Outputs to define the schema of the response we expect from the LLM.
 * - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
 * - Gemini: https://ai.google.dev/gemini-api/docs/json-mode
 *
 * @param name schema base name; the resource looked up is `/<name>_schema.json`.
 * @return the full text of the schema resource.
 * @throws IllegalStateException if the resource cannot be found.
 */
private fun readSchema(name: String): String {
    val fileName = "/${name}_schema.json"
    val resourceStream = this::class.java.getResourceAsStream(fileName)
        ?: error("Could not find $fileName in resources")

    // Read exactly once: `use` closes the reader (and with it the stream) after readText().
    return resourceStream.bufferedReader().use { it.readText() }
}

private val json = Json { ignoreUnknownKeys = true }
Expand Down Expand Up @@ -126,7 +140,7 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val defects = json.decodeFromString<ModelResponse>(aiResponse.response)
val defects = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return defects.defects
}

Expand Down Expand Up @@ -208,7 +222,52 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val response = json.decodeFromString<ModelResponse>(aiResponse.response)
val response = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return response.defects.firstOrNull()
}

/**
 * Asks the model which text on the screen matches [query] and returns that text.
 *
 * The screenshot is sent together with a prompt instructing the model to answer with a JSON
 * object that has a single `text` field. When [aiClient] is an [OpenAI] client the bundled
 * extractText schema is passed as `jsonSchema`; for every other client `jsonSchema` is `null`.
 *
 * @param aiClient client used to run the chat completion (its `defaultModel` is used).
 * @param screen image bytes, sent as the only image of the request.
 * @param query description of the text to look for; interpolated verbatim into the prompt.
 * @return the `text` field of the decoded response, or an empty string when it is null.
 */
suspend fun extractText(
    aiClient: AI,
    screen: ByteArray,
    query: String,
): String {
    val question = "What text on the screen matches the following query: $query"

    // Each section starts with an empty margin line, so it begins with a line break
    // and can be concatenated directly after the previous piece.
    val rules = """
        |
        |RULES:
        |* Provide response as a valid JSON, with structure described below.
    """.trimMargin("|")

    val responseFormat = """
        |
        |* You must provide result as a valid JSON object, matching this structure:
        |
        | {
        | "text": <string>
        | }
        |
        |DO NOT output any other information in the JSON object.
    """.trimMargin("|")

    val completion = aiClient.chatCompletion(
        question + rules + responseFormat,
        model = aiClient.defaultModel,
        maxTokens = 4096,
        // NOTE(review): the identifier reads "perform-assertion" although this is the
        // text-extraction call — confirm whether a dedicated identifier is intended.
        identifier = "perform-assertion",
        imageDetail = "high",
        images = listOf(screen),
        // The structured-output schema is only supplied to OpenAI clients.
        jsonSchema = when (aiClient) {
            is OpenAI -> json.parseToJsonElement(extractTextSchema).jsonObject
            else -> null
        },
    )

    return json.decodeFromString<ExtractTextResponse>(completion.response).text.orEmpty()
}

}
17 changes: 17 additions & 0 deletions maestro-ai/src/main/resources/extractText_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"name": "extractText",
"description": "Extracts text from an image based on a given query",
"strict": true,
"schema": {
"type": "object",
"required": [
"text"
],
"additionalProperties": false,
"properties": {
"text": {
"type": "string"
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ sealed interface Command {
fun visible(): Boolean = true

val label: String?

val optional: Boolean
}

Expand Down Expand Up @@ -65,18 +65,23 @@ data class SwipeCommand(
label != null -> {
label
}

elementSelector != null && direction != null -> {
"Swiping in $direction direction on ${elementSelector.description()}"
}

direction != null -> {
"Swiping in $direction direction in $duration ms"
}

startPoint != null && endPoint != null -> {
"Swipe from (${startPoint.x},${startPoint.y}) to (${endPoint.x},${endPoint.y}) in $duration ms"
}

startRelative != null && endRelative != null -> {
"Swipe from ($startRelative) to ($endRelative) in $duration ms"
}

else -> "Invalid input to swipe command"
}
}
Expand Down Expand Up @@ -113,11 +118,15 @@ data class ScrollUntilVisibleCommand(

/**
 * Converts a speed string into a duration string.
 *
 * The duration is `1000 * (100 - speed) / 100`, truncated to a whole number, plus 1.
 * A negative result (e.g. for speeds above 100) falls back to [DEFAULT_SCROLL_DURATION].
 *
 * @throws NumberFormatException if the receiver is not a valid whole number.
 */
private fun String.speedToDuration(): String {
    val remainingPercent = 100 - this.toLong()
    val millis = (1000 * remainingPercent.toDouble() / 100).toLong() + 1
    return when {
        millis < 0 -> DEFAULT_SCROLL_DURATION
        else -> millis.toString()
    }
}

/**
 * Returns this timeout string unchanged, or [DEFAULT_TIMEOUT_IN_MILLIS] when it parses to a
 * negative number.
 *
 * @throws NumberFormatException if the receiver is not a valid whole number.
 */
private fun String.timeoutToMillis(): String =
    if (toLong() < 0) DEFAULT_TIMEOUT_IN_MILLIS else this

override fun description(): String {
Expand Down Expand Up @@ -336,7 +345,7 @@ data class AssertCommand(
) : Command {

override fun description(): String {
if (label != null){
if (label != null) {
return label
}
val timeoutStr = timeout?.let { " within $timeout ms" } ?: ""
Expand Down Expand Up @@ -382,7 +391,8 @@ data class AssertConditionCommand(
}

/**
 * Human-readable description of this assertion.
 *
 * Returns [label] when it is set. Otherwise returns "Assert that ..." followed by the
 * condition's own description, prefixed with "(Optional) " when this command or either of
 * its `visible` / `notVisible` conditions is marked optional.
 */
override fun description(): String {
    // Named differently from the `optional` property so the local no longer shadows it.
    val isOptional = optional ||
        condition.visible?.optional == true ||
        condition.notVisible?.optional == true
    val prefix = if (isOptional) "(Optional) " else ""
    return label ?: "Assert that $prefix${condition.description()}"
}

Expand Down Expand Up @@ -425,6 +435,25 @@ data class AssertWithAICommand(
}
}

/**
 * Command describing an AI-driven text extraction for [query].
 *
 * @property query what text to look for; the only field that goes through script evaluation.
 * @property outputVariable presumably the name of the variable that receives the extracted
 *   text — confirm against the code that executes this command.
 * @property optional defaults to `true` for this command.
 * @property label optional custom description; replaces the generated one when set.
 */
data class ExtractTextWithAICommand(
    val query: String,
    val outputVariable: String,
    override val optional: Boolean = true,
    override val label: String? = null
) : Command {

    /** Returns [label] when set, otherwise a generated description that includes [query]. */
    override fun description(): String = label ?: "Extract text with AI: $query"

    /** Returns a copy with [query] script-evaluated; all other fields are carried over as-is. */
    override fun evaluateScripts(jsEngine: JsEngine): Command =
        copy(query = query.evaluateScripts(jsEngine))
}

data class InputTextCommand(
val text: String,
override val label: String? = null,
Expand Down Expand Up @@ -454,7 +483,7 @@ data class LaunchAppCommand(
) : Command {

override fun description(): String {
if (label != null){
if (label != null) {
return label
}

Expand Down Expand Up @@ -782,12 +811,15 @@ data class RepeatCommand(
label != null -> {
label
}

condition != null && timesInt > 1 -> {
"Repeat while ${condition.description()} (up to $timesInt times)"
}

condition != null -> {
"Repeat while ${condition.description()}"
}

timesInt > 1 -> "Repeat $timesInt times"
else -> "Repeat indefinitely"
}
Expand Down Expand Up @@ -824,6 +856,7 @@ data class RetryCommand(
label != null -> {
label
}

else -> "Retry $maxAttempts times"
}
}
Expand Down Expand Up @@ -943,8 +976,8 @@ data class TravelCommand(
val dLon = Math.toRadians(aLon - oLon)

val a = Math.sin(dLat / 2) * Math.sin(dLat / 2) +
Math.cos(Math.toRadians(oLat)) * Math.cos(Math.toRadians(aLat)) *
Math.sin(dLon / 2) * Math.sin(dLon / 2)
Math.cos(Math.toRadians(oLat)) * Math.cos(Math.toRadians(aLat)) *
Math.sin(dLon / 2) * Math.sin(dLon / 2)

val c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
val distance = earthRadius * c * 1000 // convert to meters
Expand All @@ -960,7 +993,12 @@ data class TravelCommand(

/**
 * Returns a copy of this command in which the latitude and longitude of every point have
 * been passed through script evaluation with [jsEngine].
 */
override fun evaluateScripts(jsEngine: JsEngine): Command {
    val evaluatedPoints = points.map { point ->
        point.copy(
            latitude = point.latitude.evaluateScripts(jsEngine),
            longitude = point.longitude.evaluateScripts(jsEngine),
        )
    }
    // `points` is passed exactly once; every other field keeps its current value.
    return copy(points = evaluatedPoints)
}

Expand All @@ -987,7 +1025,7 @@ data class AddMediaCommand(
val mediaPaths: List<String>,
override val label: String? = null,
override val optional: Boolean = false,
): Command {
) : Command {

override fun description(): String {
return label ?: "Adding media files(${mediaPaths.size}) to the device"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data class MaestroCommand(
val assertConditionCommand: AssertConditionCommand? = null,
val assertNoDefectsWithAICommand: AssertNoDefectsWithAICommand? = null,
val assertWithAICommand: AssertWithAICommand? = null,
val extractTextWithAICommand: ExtractTextWithAICommand? = null,
val inputTextCommand: InputTextCommand? = null,
val inputRandomTextCommand: InputRandomCommand? = null,
val launchAppCommand: LaunchAppCommand? = null,
Expand Down Expand Up @@ -82,6 +83,7 @@ data class MaestroCommand(
assertConditionCommand = command as? AssertConditionCommand,
assertNoDefectsWithAICommand = command as? AssertNoDefectsWithAICommand,
assertWithAICommand = command as? AssertWithAICommand,
extractTextWithAICommand = command as? ExtractTextWithAICommand,
inputTextCommand = command as? InputTextCommand,
inputRandomTextCommand = command as? InputRandomCommand,
launchAppCommand = command as? LaunchAppCommand,
Expand Down Expand Up @@ -125,6 +127,7 @@ data class MaestroCommand(
assertConditionCommand != null -> assertConditionCommand
assertNoDefectsWithAICommand != null -> assertNoDefectsWithAICommand
assertWithAICommand != null -> assertWithAICommand
extractTextWithAICommand != null -> extractTextWithAICommand
inputTextCommand != null -> inputTextCommand
inputRandomTextCommand != null -> inputRandomTextCommand
launchAppCommand != null -> launchAppCommand
Expand Down
Loading
Loading