diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt index 24540dcc3dc..32004915256 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt @@ -30,7 +30,6 @@ import java.time.Duration import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking -import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder @@ -339,12 +338,7 @@ class HonestMajorityShareShuffleMillTest { DUCHY_THREE_NAME to workerStub, ) - private lateinit var aggregatorMill: HonestMajorityShareShuffleMill - private lateinit var firstWorkerMill: HonestMajorityShareShuffleMill - private lateinit var secondWorkerMill: HonestMajorityShareShuffleMill - - @Before - fun initializeMill() = runBlocking { + private fun createHmssMill(duchyName: String): HonestMajorityShareShuffleMill { DuchyInfo.setForTest(setOf(DUCHY_ONE_NAME, DUCHY_TWO_NAME, DUCHY_THREE_NAME)) val csCertificate = Certificate(DUCHY_CERT_NAME, DUCHY_SIGNING_CERT) @@ -356,69 +350,26 @@ class HonestMajorityShareShuffleMillTest { val trustedCertificates = duchyCertificates + dataProviderCertificates - firstWorkerMill = - HonestMajorityShareShuffleMill( - millId = DUCHY_ONE_NAME + MILL_ID_SUFFIX, - duchyId = DUCHY_ONE_NAME, - signingKey = DUCHY_SIGNING_KEY, - consentSignalCert = csCertificate, - trustedCertificates = trustedCertificates, - dataClients = computationDataClients, - systemComputationParticipantsClient = systemComputationParticipantsStub, - systemComputationsClient = systemComputationStub, - systemComputationLogEntriesClient = systemComputationLogEntriesStub, - computationStatsClient = computationStatsStub, - privateKeyStore = privateKeyStore, - certificateClient = certificateStub, - workerStubs = workerStubs, - cryptoWorker = mockCryptoWorker, - workLockDuration = Duration.ofMinutes(5), - openTelemetry = GlobalOpenTelemetry.get(), - requestChunkSizeBytes = 20, - maximumAttempts = 2, - ) - secondWorkerMill = - HonestMajorityShareShuffleMill( - millId = DUCHY_TWO_NAME + MILL_ID_SUFFIX, - duchyId = DUCHY_TWO_NAME, - signingKey = DUCHY_SIGNING_KEY, - consentSignalCert = csCertificate, - trustedCertificates = trustedCertificates, - dataClients = computationDataClients, - systemComputationParticipantsClient = systemComputationParticipantsStub, - systemComputationsClient = systemComputationStub, - systemComputationLogEntriesClient = systemComputationLogEntriesStub, - computationStatsClient = computationStatsStub, - privateKeyStore = privateKeyStore, - certificateClient = certificateStub, - workerStubs = workerStubs, - cryptoWorker = mockCryptoWorker, - workLockDuration = Duration.ofMinutes(5), - openTelemetry = GlobalOpenTelemetry.get(), - requestChunkSizeBytes = 20, - maximumAttempts = 2, - ) - aggregatorMill = - HonestMajorityShareShuffleMill( - millId = DUCHY_THREE_NAME + MILL_ID_SUFFIX, - duchyId = DUCHY_THREE_NAME, - signingKey = DUCHY_SIGNING_KEY, - consentSignalCert = csCertificate, - trustedCertificates = trustedCertificates, - dataClients = computationDataClients, - systemComputationParticipantsClient = systemComputationParticipantsStub, - systemComputationsClient = systemComputationStub, - systemComputationLogEntriesClient = systemComputationLogEntriesStub, - computationStatsClient = computationStatsStub, - privateKeyStore = privateKeyStore, - certificateClient = certificateStub, - workerStubs = workerStubs, - cryptoWorker = mockCryptoWorker, - workLockDuration = Duration.ofMinutes(5), - openTelemetry = GlobalOpenTelemetry.get(), - requestChunkSizeBytes = 20, - maximumAttempts = 2, - ) + return HonestMajorityShareShuffleMill( + millId = duchyName + MILL_ID_SUFFIX, + duchyId = duchyName, + signingKey = DUCHY_SIGNING_KEY, + consentSignalCert = csCertificate, + trustedCertificates = trustedCertificates, + dataClients = computationDataClients, + systemComputationParticipantsClient = systemComputationParticipantsStub, + systemComputationsClient = systemComputationStub, + systemComputationLogEntriesClient = systemComputationLogEntriesStub, + computationStatsClient = computationStatsStub, + privateKeyStore = privateKeyStore, + certificateClient = certificateStub, + workerStubs = workerStubs, + cryptoWorker = mockCryptoWorker, + workLockDuration = Duration.ofMinutes(5), + openTelemetry = GlobalOpenTelemetry.get(), + requestChunkSizeBytes = 20, + maximumAttempts = 2, + ) } private suspend fun getHmssComputationDetails(role: RoleInComputation): ComputationDetails { @@ -467,7 +418,8 @@ class HonestMajorityShareShuffleMillTest { requisitions = REQUISITIONS, ) - firstWorkerMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_ONE_NAME) + mill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( @@ -521,7 +473,8 @@ class HonestMajorityShareShuffleMillTest { requisitions = REQUISITIONS, ) - firstWorkerMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_ONE_NAME) + mill.pollAndProcessNextComputation() val updatedToken = fakeComputationDb[LOCAL_ID] @@ -602,7 +555,8 @@ class HonestMajorityShareShuffleMillTest { requisitions = requisitions, ) - secondWorkerMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_TWO_NAME) + mill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( @@ -710,7 +664,8 @@ class HonestMajorityShareShuffleMillTest { } } - firstWorkerMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_ONE_NAME) + mill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( @@ -806,7 +761,8 @@ class HonestMajorityShareShuffleMillTest { whenever(mockCertificates.getCertificate(any())) .thenThrow(Status.NOT_FOUND.asRuntimeException()) - firstWorkerMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_ONE_NAME) + mill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( @@ -869,7 +825,8 @@ class HonestMajorityShareShuffleMillTest { } } - aggregatorMill.pollAndProcessNextComputation() + val mill = createHmssMill(DUCHY_THREE_NAME) + mill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo(