Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check Measurement ConsumerId in EdpSimulator #1067

Merged
merged 7 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ import org.wfanet.measurement.api.v2alpha.LiquidLegionsSketchParams
import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt.filter
import org.wfanet.measurement.api.v2alpha.Measurement
import org.wfanet.measurement.api.v2alpha.MeasurementConsumer
import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementKey
import org.wfanet.measurement.api.v2alpha.MeasurementKt
import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.frequency
import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.impression
Expand Down Expand Up @@ -382,13 +384,19 @@ class EdpSimulator(
/** Executes the requisition fulfillment workflow. */
suspend fun executeRequisitionFulfillingWorkflow() {
logger.info("Executing requisitionFulfillingWorkflow...")
val requisitions = getRequisitions()
val requisitions =
getRequisitions().filter {
MeasurementKey.fromName(it.measurement)!!.measurementConsumerId ==
MeasurementConsumerKey.fromName(measurementConsumerName)!!.measurementConsumerId
}

if (requisitions.isEmpty()) {
logger.fine("No unfulfilled requisition. Polling again later...")
return
}

for (requisition in requisitions) {
println("requisitionrequisitionrequisitionrequisitionrequisition ${requisition.measurement}")
logger.info("Processing requisition ${requisition.name}...")
val measurementConsumerCertificate: Certificate =
getCertificate(requisition.measurementConsumerCertificate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ import java.time.LocalDate
import java.time.ZoneOffset
import kotlin.random.Random
import kotlinx.coroutines.runBlocking
import org.junit.BeforeClass
import org.junit.ClassRule
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
Expand Down Expand Up @@ -152,7 +151,8 @@ import org.wfanet.measurement.loadtest.storage.SketchStore
import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient

private const val TEMPLATE_PREFIX = "wfa.measurement.api.v2alpha.event_templates.testing"
private const val MC_NAME = "mc"
private const val MC_ID = "mc"
private const val MC_NAME = "measurementConsumers/$MC_ID"
private val EVENT_TEMPLATES =
listOf("$TEMPLATE_PREFIX.Video", "$TEMPLATE_PREFIX.BannerAd", "$TEMPLATE_PREFIX.Person")
private const val EDP_DISPLAY_NAME = "edp1"
Expand Down Expand Up @@ -195,6 +195,7 @@ private val SKETCH_CONFIG = sketchConfig {
private val MEASUREMENT_CONSUMER_CERTIFICATE_DER =
SECRET_FILES_PATH.resolve("mc_cs_cert.der").toFile().readByteString()
private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs"
private const val MEASUREMENT_NAME = "$MC_NAME/measurements/BBBBBBBBBHs"
private const val MEASUREMENT_CONSUMER_CERTIFICATE_NAME =
"$MEASUREMENT_CONSUMER_NAME/certificates/AAAAAAAAAcg"
private val MEASUREMENT_CONSUMER_CERTIFICATE = certificate {
Expand All @@ -204,7 +205,7 @@ private val MEASUREMENT_CONSUMER_CERTIFICATE = certificate {
private val MEASUREMENT_PUBLIC_KEY =
encryptionPublicKey {
format = EncryptionPublicKey.Format.TINK_KEYSET
data = SECRET_FILES_PATH.resolve("${MC_NAME}_enc_public.tink").toFile().readByteString()
data = SECRET_FILES_PATH.resolve("${MC_ID}_enc_public.tink").toFile().readByteString()
}
.toByteString()

Expand Down Expand Up @@ -279,6 +280,10 @@ class EdpSimulatorTest {
private val requisitionFulfillmentServiceMock: RequisitionFulfillmentCoroutineImplBase =
mockService()

@get:Rule public val temporaryFolder: TemporaryFolder = TemporaryFolder()

private lateinit var sketchStore: SketchStore

@get:Rule
val grpcTestServerRule = GrpcTestServerRule {
addService(measurementConsumersServiceMock)
Expand All @@ -289,6 +294,11 @@ class EdpSimulatorTest {
addService(requisitionFulfillmentServiceMock)
}

@Before
fun setup() {
sketchStore = SketchStore(FileSystemStorageClient(temporaryFolder.root))
}

private val measurementConsumersStub by lazy {
MeasurementConsumersCoroutineStub(grpcTestServerRule.channel)
}
Expand Down Expand Up @@ -355,6 +365,69 @@ class EdpSimulatorTest {
}
}

@Test
fun `Does nothing for requisitions with different Measumrent Consumer Id`() {
runBlocking {
val allEvents =
generateEvents(
1L..10L,
FIRST_EVENT_DATE,
Person.AgeGroup.YEARS_18_TO_34,
Person.Gender.FEMALE
) +
generateEvents(
11L..15L,
FIRST_EVENT_DATE,
Person.AgeGroup.YEARS_35_TO_54,
Person.Gender.FEMALE
) +
generateEvents(
16L..20L,
FIRST_EVENT_DATE,
Person.AgeGroup.YEARS_55_PLUS,
Person.Gender.FEMALE
) +
generateEvents(
21L..25L,
FIRST_EVENT_DATE,
Person.AgeGroup.YEARS_18_TO_34,
Person.Gender.MALE
) +
generateEvents(
26L..30L,
FIRST_EVENT_DATE,
Person.AgeGroup.YEARS_35_TO_54,
Person.Gender.MALE
)

val random = java.util.Random()

val edpSimulator =
EdpSimulator(
EDP_DATA,
"measurementConsumers/differentMcId",
measurementConsumersStub,
certificatesStub,
eventGroupsStub,
eventGroupMetadataDescriptorsStub,
requisitionsStub,
requisitionFulfillmentStub,
sketchStore,
InMemoryEventQuery(allEvents),
MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofMillis(1000)),
EVENT_TEMPLATES,
privacyBudgetManager,
TRUSTED_CERTIFICATES,
random,
DIRECT_NOISE_MECHANISM
)
edpSimulator.createEventGroup()
edpSimulator.executeRequisitionFulfillingWorkflow()
val storedSketch = sketchStore.get(REQUISITION_ONE)?.read()?.flatten()
assertThat(storedSketch).isNull()
}
}

@Test
fun `filters events, charges privacy budget and generates sketch successfully`() {
runBlocking {
Expand Down Expand Up @@ -836,8 +909,7 @@ class EdpSimulatorTest {
}

companion object {
private val MC_SIGNING_KEY =
loadSigningKey("${MC_NAME}_cs_cert.der", "${MC_NAME}_cs_private.der")
private val MC_SIGNING_KEY = loadSigningKey("${MC_ID}_cs_cert.der", "${MC_ID}_cs_private.der")
private val DUCHY_SIGNING_KEY =
loadSigningKey("${DUCHY_ID}_cs_cert.der", "${DUCHY_ID}_cs_private.der")

Expand Down Expand Up @@ -884,6 +956,7 @@ class EdpSimulatorTest {

private val REQUISITION_ONE = requisition {
name = "requisition_one"
measurement = MEASUREMENT_NAME
state = Requisition.State.UNFULFILLED
measurementConsumerCertificate = MEASUREMENT_CONSUMER_CERTIFICATE_NAME
measurementSpec = signMeasurementSpec(MEASUREMENT_SPEC, MC_SIGNING_KEY)
Expand Down Expand Up @@ -919,8 +992,6 @@ class EdpSimulatorTest {
readCertificateCollection(SECRET_FILES_PATH.resolve("edp_trusted_certs.pem").toFile())
.associateBy { requireNotNull(it.authorityKeyIdentifier) }

@JvmField @ClassRule val temporaryFolder: TemporaryFolder = TemporaryFolder()

private fun loadSigningKey(
certDerFileName: String,
privateKeyDerFileName: String
Expand Down Expand Up @@ -950,13 +1021,5 @@ class EdpSimulatorTest {
private fun assertAnySketchEquals(sketch: AnySketch, other: AnySketch) {
assertThat(sketch).comparingElementsUsing(EQUIVALENCE).containsExactlyElementsIn(other)
}

private lateinit var sketchStore: SketchStore

@JvmStatic
@BeforeClass
fun initSketchStore() {
sketchStore = SketchStore(FileSystemStorageClient(temporaryFolder.root))
}
}
}