Skip to content

Commit

Permalink
Update based on PR comments. #4
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh committed Mar 27, 2024
1 parent c692f3a commit afc4972
Showing 1 changed file with 33 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import java.time.Duration
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
Expand Down Expand Up @@ -339,12 +338,7 @@ class HonestMajorityShareShuffleMillTest {
DUCHY_THREE_NAME to workerStub,
)

private lateinit var aggregatorMill: HonestMajorityShareShuffleMill
private lateinit var firstWorkerMill: HonestMajorityShareShuffleMill
private lateinit var secondWorkerMill: HonestMajorityShareShuffleMill

@Before
fun initializeMill() = runBlocking {
private fun createHmssMill(duchyName: String): HonestMajorityShareShuffleMill {
DuchyInfo.setForTest(setOf(DUCHY_ONE_NAME, DUCHY_TWO_NAME, DUCHY_THREE_NAME))

val csCertificate = Certificate(DUCHY_CERT_NAME, DUCHY_SIGNING_CERT)
Expand All @@ -356,69 +350,26 @@ class HonestMajorityShareShuffleMillTest {

val trustedCertificates = duchyCertificates + dataProviderCertificates

firstWorkerMill =
HonestMajorityShareShuffleMill(
millId = DUCHY_ONE_NAME + MILL_ID_SUFFIX,
duchyId = DUCHY_ONE_NAME,
signingKey = DUCHY_SIGNING_KEY,
consentSignalCert = csCertificate,
trustedCertificates = trustedCertificates,
dataClients = computationDataClients,
systemComputationParticipantsClient = systemComputationParticipantsStub,
systemComputationsClient = systemComputationStub,
systemComputationLogEntriesClient = systemComputationLogEntriesStub,
computationStatsClient = computationStatsStub,
privateKeyStore = privateKeyStore,
certificateClient = certificateStub,
workerStubs = workerStubs,
cryptoWorker = mockCryptoWorker,
workLockDuration = Duration.ofMinutes(5),
openTelemetry = GlobalOpenTelemetry.get(),
requestChunkSizeBytes = 20,
maximumAttempts = 2,
)
secondWorkerMill =
HonestMajorityShareShuffleMill(
millId = DUCHY_TWO_NAME + MILL_ID_SUFFIX,
duchyId = DUCHY_TWO_NAME,
signingKey = DUCHY_SIGNING_KEY,
consentSignalCert = csCertificate,
trustedCertificates = trustedCertificates,
dataClients = computationDataClients,
systemComputationParticipantsClient = systemComputationParticipantsStub,
systemComputationsClient = systemComputationStub,
systemComputationLogEntriesClient = systemComputationLogEntriesStub,
computationStatsClient = computationStatsStub,
privateKeyStore = privateKeyStore,
certificateClient = certificateStub,
workerStubs = workerStubs,
cryptoWorker = mockCryptoWorker,
workLockDuration = Duration.ofMinutes(5),
openTelemetry = GlobalOpenTelemetry.get(),
requestChunkSizeBytes = 20,
maximumAttempts = 2,
)
aggregatorMill =
HonestMajorityShareShuffleMill(
millId = DUCHY_THREE_NAME + MILL_ID_SUFFIX,
duchyId = DUCHY_THREE_NAME,
signingKey = DUCHY_SIGNING_KEY,
consentSignalCert = csCertificate,
trustedCertificates = trustedCertificates,
dataClients = computationDataClients,
systemComputationParticipantsClient = systemComputationParticipantsStub,
systemComputationsClient = systemComputationStub,
systemComputationLogEntriesClient = systemComputationLogEntriesStub,
computationStatsClient = computationStatsStub,
privateKeyStore = privateKeyStore,
certificateClient = certificateStub,
workerStubs = workerStubs,
cryptoWorker = mockCryptoWorker,
workLockDuration = Duration.ofMinutes(5),
openTelemetry = GlobalOpenTelemetry.get(),
requestChunkSizeBytes = 20,
maximumAttempts = 2,
)
return HonestMajorityShareShuffleMill(
millId = duchyName + MILL_ID_SUFFIX,
duchyId = duchyName,
signingKey = DUCHY_SIGNING_KEY,
consentSignalCert = csCertificate,
trustedCertificates = trustedCertificates,
dataClients = computationDataClients,
systemComputationParticipantsClient = systemComputationParticipantsStub,
systemComputationsClient = systemComputationStub,
systemComputationLogEntriesClient = systemComputationLogEntriesStub,
computationStatsClient = computationStatsStub,
privateKeyStore = privateKeyStore,
certificateClient = certificateStub,
workerStubs = workerStubs,
cryptoWorker = mockCryptoWorker,
workLockDuration = Duration.ofMinutes(5),
openTelemetry = GlobalOpenTelemetry.get(),
requestChunkSizeBytes = 20,
maximumAttempts = 2,
)
}

private suspend fun getHmssComputationDetails(role: RoleInComputation): ComputationDetails {
Expand Down Expand Up @@ -467,7 +418,8 @@ class HonestMajorityShareShuffleMillTest {
requisitions = REQUISITIONS,
)

firstWorkerMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_ONE_NAME)
mill.pollAndProcessNextComputation()

assertThat(fakeComputationDb[LOCAL_ID])
.isEqualTo(
Expand Down Expand Up @@ -521,7 +473,8 @@ class HonestMajorityShareShuffleMillTest {
requisitions = REQUISITIONS,
)

firstWorkerMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_ONE_NAME)
mill.pollAndProcessNextComputation()

val updatedToken = fakeComputationDb[LOCAL_ID]

Expand Down Expand Up @@ -602,7 +555,8 @@ class HonestMajorityShareShuffleMillTest {
requisitions = requisitions,
)

secondWorkerMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_TWO_NAME)
mill.pollAndProcessNextComputation()

assertThat(fakeComputationDb[LOCAL_ID])
.isEqualTo(
Expand Down Expand Up @@ -710,7 +664,8 @@ class HonestMajorityShareShuffleMillTest {
}
}

firstWorkerMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_ONE_NAME)
mill.pollAndProcessNextComputation()

assertThat(fakeComputationDb[LOCAL_ID])
.isEqualTo(
Expand Down Expand Up @@ -806,7 +761,8 @@ class HonestMajorityShareShuffleMillTest {
whenever(mockCertificates.getCertificate(any()))
.thenThrow(Status.NOT_FOUND.asRuntimeException())

firstWorkerMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_ONE_NAME)
mill.pollAndProcessNextComputation()

assertThat(fakeComputationDb[LOCAL_ID])
.isEqualTo(
Expand Down Expand Up @@ -869,7 +825,8 @@ class HonestMajorityShareShuffleMillTest {
}
}

aggregatorMill.pollAndProcessNextComputation()
val mill = createHmssMill(DUCHY_THREE_NAME)
mill.pollAndProcessNextComputation()

assertThat(fakeComputationDb[LOCAL_ID])
.isEqualTo(
Expand Down

0 comments on commit afc4972

Please sign in to comment.