Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cross-media-measurement-api version #1123

Merged
merged 6 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ load(
"K8S_CLIENT_VERSION",
"OPEN_TELEMETRY_SDK_VERSION",
)

load("//build:repositories.bzl", "wfa_measurement_system_repositories")

wfa_measurement_system_repositories()
Expand Down
4 changes: 2 additions & 2 deletions build/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def wfa_measurement_system_repositories():
wfa_repo_archive(
name = "wfa_measurement_proto",
repo = "cross-media-measurement-api",
sha256 = "22f32f247c95d5c6efab8b00ecf3019268f293caf5065e1e0ab738419ad3c1d0",
version = "0.38.1",
sha256 = "cc327047bc094768c46a45b6e7a1cde3d0dfc3a89585f316932f0abbf78d2612",
version = "0.39.0",
)

wfa_repo_archive(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ kt_jvm_library(
"//src/main/proto/wfa/measurement/api/v2alpha:api_keys_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:data_providers_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:date_interval_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:model_lines_service_kt_jvm_grpc_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.google.crypto.tink.BinaryKeysetReader
import com.google.crypto.tink.CleartextKeysetHandle
import com.google.protobuf.ByteString
import com.google.protobuf.kotlin.toByteString
import com.google.type.date
import com.google.type.interval
import io.grpc.ManagedChannel
import java.io.File
Expand All @@ -30,6 +31,8 @@ import java.security.cert.X509Certificate
import java.time.Clock
import java.time.Duration as systemDuration
import java.time.Instant
import java.time.ZoneId
import java.util.*
import kotlin.properties.Delegates
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
Expand Down Expand Up @@ -93,6 +96,7 @@ import org.wfanet.measurement.api.v2alpha.createModelReleaseRequest
import org.wfanet.measurement.api.v2alpha.createModelRolloutRequest
import org.wfanet.measurement.api.v2alpha.createModelShardRequest
import org.wfanet.measurement.api.v2alpha.createModelSuiteRequest
import org.wfanet.measurement.api.v2alpha.dateInterval
import org.wfanet.measurement.api.v2alpha.deleteModelOutageRequest
import org.wfanet.measurement.api.v2alpha.deleteModelRolloutRequest
import org.wfanet.measurement.api.v2alpha.deleteModelShardRequest
Expand Down Expand Up @@ -1770,23 +1774,23 @@ private class ModelRollouts {
)
modelLineName: String,
@Option(
names = ["--rollout-start-time"],
description = ["Start time of model rollout in ISO 8601 format of UTC"],
names = ["--rollout-start-date"],
description = ["Start date of model rollout in ISO 8601 format of UTC"],
required = false,
)
rolloutStartTime: Instant? = null,
rolloutStartDate: Instant? = null,
@Option(
names = ["--rollout-end-time"],
description = ["End time of model rollout in ISO 8601 format of UTC"],
names = ["--rollout-end-date"],
description = ["End date of model rollout in ISO 8601 format of UTC"],
required = false,
)
rolloutEndTime: Instant? = null,
rolloutEndDate: Instant? = null,
@Option(
names = ["--instant-rollout-time"],
description = ["Instant rollout time of model rollout in ISO 8601 format of UTC"],
names = ["--instant-rollout-date"],
description = ["Instant rollout date of model rollout in ISO 8601 format of UTC"],
required = false,
)
instantRolloutTime: Instant? = null,
instantRolloutDate: Instant? = null,
@Option(
names = ["--model-release"],
description = ["The `ModelRelease` this model rollout refers to."],
Expand All @@ -1795,22 +1799,38 @@ private class ModelRollouts {
modelRolloutRelease: String,
) {

if (instantRolloutTime == null && (rolloutStartTime == null || rolloutEndTime == null)) {
if (instantRolloutDate == null && (rolloutStartDate == null || rolloutEndDate == null)) {
throw ParameterException(
parentCommand.commandLine,
"Both `rolloutStartTime` and `rolloutEndTime` must be set when `instantRolloutTime` is not."
"Both `rolloutStartDate` and `rolloutEndDate` must be set when `instantRolloutDate` is not."
)
}

val request = createModelRolloutRequest {
parent = modelLineName
modelRollout = modelRollout {
if (instantRolloutTime != null) {
this.instantRolloutTime = instantRolloutTime.toProtoTime()
if (instantRolloutDate != null) {
val instantRolloutDate = instantRolloutDate.atZone(ZoneId.of("UTC")).toLocalDate()
this.instantRolloutDate = date {
year = instantRolloutDate.year
month = instantRolloutDate.monthValue
day = instantRolloutDate.dayOfMonth
}
} else {
gradualRolloutPeriod = interval {
startTime = rolloutStartTime!!.toProtoTime()
endTime = rolloutEndTime!!.toProtoTime()
// TODO(@MarcoPremier): Move Instant to google.type.Date conversion to common.jvm
val startDate = rolloutStartDate!!.atZone(ZoneId.of("UTC")).toLocalDate()
val endDate = rolloutEndDate!!.atZone(ZoneId.of("UTC")).toLocalDate()
gradualRolloutPeriod = dateInterval {
this.startDate = date {
year = startDate.year
month = startDate.monthValue
day = startDate.dayOfMonth
}
this.endDate = date {
year = endDate.year
month = endDate.monthValue
day = endDate.dayOfMonth
}
}
}
modelRelease = modelRolloutRelease
Expand Down Expand Up @@ -1854,25 +1874,35 @@ private class ModelRollouts {
["Start time of overlapping period for desired model rollouts in ISO 8601 format of UTC"],
required = false,
)
rolloutPeriodStartTime: Instant? = null,
rolloutPeriodStartDate: Instant? = null,
@Option(
names = ["--rollout-period-overlapping-end-time"],
description =
["End time of overlapping period for desired model rollouts in ISO 8601 format of UTC"],
required = false,
)
rolloutPeriodEndTime: Instant? = null,
rolloutPeriodEndDate: Instant? = null,
) {
val request = listModelRolloutsRequest {
parent = modelLineName
pageSize = listPageSize
pageToken = listPageToken
if (rolloutPeriodStartTime != null && rolloutPeriodEndTime != null) {
if (rolloutPeriodStartDate != null && rolloutPeriodEndDate != null) {
filter =
ListModelRolloutsRequestKt.filter {
rolloutPeriodOverlapping = interval {
startTime = rolloutPeriodStartTime.toProtoTime()
endTime = rolloutPeriodEndTime.toProtoTime()
val startDate = rolloutPeriodStartDate.atZone(ZoneId.of("UTC")).toLocalDate()
val endDate = rolloutPeriodEndDate.atZone(ZoneId.of("UTC")).toLocalDate()
rolloutPeriodOverlapping = dateInterval {
this.startDate = date {
year = startDate.year
month = startDate.monthValue
day = startDate.dayOfMonth
}
this.endDate = date {
year = endDate.year
month = endDate.monthValue
day = endDate.dayOfMonth
}
}
}
}
Expand All @@ -1895,19 +1925,24 @@ private class ModelRollouts {
description = ["The rollout freeze time to be set in ISO 8601 format of UTC."],
required = true,
)
freezeTime: Instant,
freezeDate: Instant,
) {
val rolloutFreezeDate = freezeDate.atZone(ZoneId.of("UTC")).toLocalDate()
val request = scheduleModelRolloutFreezeRequest {
name = modelRolloutName
rolloutFreezeTime = freezeTime.toProtoTime()
this.rolloutFreezeDate = date {
year = rolloutFreezeDate.year
month = rolloutFreezeDate.monthValue
day = rolloutFreezeDate.dayOfMonth
}
}
val outputModelRollout =
runBlocking(parentCommand.rpcDispatcher) {
modelRolloutStub.scheduleModelRolloutFreeze(request)
}

println(
"Freeze time ${outputModelRollout.rolloutFreezeTime} has been set for ${outputModelRollout.name}."
"Freeze date ${outputModelRollout.rolloutFreezeDate} has been set for ${outputModelRollout.name}."
)
}

Expand All @@ -1928,12 +1963,12 @@ private class ModelRollouts {

private fun printModelRollout(modelRollout: ModelRollout) {
println("NAME - ${modelRollout.name}")
if (modelRollout.hasInstantRolloutTime()) {
println("INSTANT ROLLOUT TIME- ${modelRollout.instantRolloutTime}")
if (modelRollout.hasInstantRolloutDate()) {
println("INSTANT ROLLOUT DATE- ${modelRollout.instantRolloutDate}")
} else {
println("GRADUAL ROLLOUT PERIOD- ${modelRollout.gradualRolloutPeriod}")
}
println("ROLLOUT FREEZE TIME - ${modelRollout.rolloutFreezeTime}")
println("ROLLOUT FREEZE DATE - ${modelRollout.rolloutFreezeDate}")
println("PREVIOUS MODEL ROLLOUT - ${modelRollout.previousModelRollout}")
println("MODEL RELEASE - ${modelRollout.modelRelease}")
println("CREATE TIME - ${modelRollout.createTime}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor",
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key",
"//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:date_interval_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:model_rollout_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:model_rollouts_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:page_token_kt_jvm_proto",
Expand Down Expand Up @@ -436,6 +437,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config",
"//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:date_interval_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:differential_privacy_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:exchange_kt_jvm_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
package org.wfanet.measurement.kingdom.service.api.v2alpha

import com.google.protobuf.Empty
import com.google.protobuf.util.Timestamps
import com.google.protobuf.Timestamp
import com.google.protobuf.timestamp
import com.google.type.Date
import com.google.type.date
import com.google.type.interval
import io.grpc.Status
import io.grpc.StatusException
import java.time.Instant
import java.time.LocalDate
import java.time.ZoneOffset
import kotlin.math.min
import kotlinx.coroutines.flow.toList
import org.wfanet.measurement.api.v2alpha.CreateModelRolloutRequest
Expand Down Expand Up @@ -145,7 +152,7 @@ class ModelRolloutsService(private val internalClient: ModelRolloutsCoroutineStu
externalModelSuiteId = apiIdToExternalId(key.modelSuiteId)
externalModelProviderId = apiIdToExternalId(key.modelProviderId)
externalModelRolloutId = apiIdToExternalId(key.modelRolloutId)
rolloutFreezeTime = request.rolloutFreezeTime
rolloutFreezeTime = request.rolloutFreezeDate.toProtoTime()
}

try {
Expand Down Expand Up @@ -298,14 +305,10 @@ class ModelRolloutsService(private val internalClient: ModelRolloutsCoroutineStu
grpcRequire(
source.hasFilter() &&
source.filter.hasRolloutPeriodOverlapping() &&
Timestamps.compare(
source.filter.rolloutPeriodOverlapping.startTime,
this.rolloutPeriodOverlapping.startTime
) == 0 &&
Timestamps.compare(
source.filter.rolloutPeriodOverlapping.endTime,
this.rolloutPeriodOverlapping.endTime
) == 0
source.filter.rolloutPeriodOverlapping.startDate ==
this.rolloutPeriodOverlapping.startTime.toDate() &&
source.filter.rolloutPeriodOverlapping.endDate ==
this.rolloutPeriodOverlapping.endTime.toDate()
) {
"Arguments must be kept the same when using a page token"
}
Expand All @@ -327,7 +330,10 @@ class ModelRolloutsService(private val internalClient: ModelRolloutsCoroutineStu
this.externalModelSuiteId = externalModelSuiteId
this.externalModelLineId = externalModelLineId
if (source.hasFilter() && source.filter.hasRolloutPeriodOverlapping()) {
this.rolloutPeriodOverlapping = source.filter.rolloutPeriodOverlapping
this.rolloutPeriodOverlapping = interval {
startTime = source.filter.rolloutPeriodOverlapping.startDate.toProtoTime()
endTime = source.filter.rolloutPeriodOverlapping.endDate.toProtoTime()
}
}
}
}
Expand Down Expand Up @@ -364,4 +370,25 @@ class ModelRolloutsService(private val internalClient: ModelRolloutsCoroutineStu
}
}
}

// TODO(@MarcoPremier): Move this function to common-jvm.
private fun Timestamp.toDate(): Date {
val instant = Instant.ofEpochSecond(seconds, nanos.toLong())
val localDate = instant.atZone(ZoneOffset.UTC).toLocalDate()
return date {
year = localDate.year
month = localDate.monthValue
day = localDate.dayOfMonth
}
}

// TODO(@MarcoPremier): Move this function to common-jvm.
private fun Date.toProtoTime(): Timestamp {
val localDate = LocalDate.of(year, month, day)
val instant = localDate.atStartOfDay().toInstant(java.time.ZoneOffset.UTC)
return timestamp {
seconds = instant.epochSecond
nanos = instant.nano
}
}
}
Loading