Skip to content

Commit

Permalink
Joint mechanisms (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonnyascott authored Dec 2, 2024
1 parent 1026c6f commit ca8af99
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/source/reference/privacy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Privacy mechanisms
.. automodule:: pfl.privacy.ftrl_mechanism
:members:

.. automodule:: pfl.privacy.joint_mechanism
:members:

Privacy accountants
-------------------

Expand Down
150 changes: 150 additions & 0 deletions pfl/privacy/joint_mechanism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright © 2023-2024 Apple Inc.
'''
Joint mechanism for combining multiple mechanisms into one.
'''

from typing import Dict, List, Optional, Set, Tuple

from pfl.metrics import Metrics, StringMetricName
from pfl.stats import MappedVectorStatistics, TrainingStatistics

from .privacy_mechanism import CentrallyApplicablePrivacyMechanism


class JointMechanism(CentrallyApplicablePrivacyMechanism):
"""
Constructs a new CentrallyApplicablePrivacyMechanism from existing ones.
Each existing mechanism is applied to a disjoint subset of the client
statistics keys. As such JointMechanism can only be applied to client
statistics of type MappedVectorStatistics.
:param mechanisms_and_keys:
Dictionary which maps a name of a mechanism to a tuple consisting of
the corresponding CentrallyApplicablePrivacyMechanism and a list specifying
which keys of the user training statistics that mechanism should be applied
to. This list can contain strings of the following forms:
* key_name - an exact key name as appearing in the user statistics,
* f'{key_prefix}/' - matches to all user statistics keys of the form
f'{key_prefix}/{any_string}'.
Finally, note that the names of each of the mechanisms must be distinct
for the purpose of naming the corresponding Metrics.
"""

def __init__(
self,
mechanisms_and_keys: Dict[str,
Tuple[CentrallyApplicablePrivacyMechanism,
List[str]]]):
if len(set(mechanisms_and_keys.keys())) < len(
mechanisms_and_keys.keys()):
raise ValueError('Mechanism names must be unique.')
self.mechanisms_and_keys = mechanisms_and_keys

def constrain_sensitivity(
self,
statistics: TrainingStatistics,
name_formatting_fn=lambda n: StringMetricName(n),
seed: Optional[int] = None) -> Tuple[TrainingStatistics, Metrics]:

if not isinstance(statistics, MappedVectorStatistics):
raise TypeError(
'Statistics must be of type MappedVectorStatistics.')

clipped_statistics: MappedVectorStatistics = MappedVectorStatistics(
weight=statistics.weight)
metrics = Metrics()
client_statistics_keys = set(statistics.keys())
for mechanism_name, (
mechanism, mechanism_keys) in self.mechanisms_and_keys.items():

def mechanism_name_formatting_fn(n, prefix=mechanism_name):
return name_formatting_fn(f'{prefix} | {n}')

# Extract client statistics keys that match the keys for current mechanism
sub_statistics: MappedVectorStatistics = MappedVectorStatistics()
for key in mechanism_keys:
if key in client_statistics_keys: # exact key name
sub_statistics[key] = statistics[key]
client_statistics_keys.remove(key)
else:
assert key[
-1] == '/', f"{key} does not appear as a key in the client statistics."
for client_key in statistics:
if client_key.startswith(
key): # matches f'{key_prefix}/'
sub_statistics[client_key] = statistics[client_key]
client_statistics_keys.discard(client_key)

# Clip statistics using mechanism
clipped_sub_statistics, sub_metrics = mechanism.constrain_sensitivity(
sub_statistics, mechanism_name_formatting_fn, seed)

# Recombine clipped statistics and metrics
for key in clipped_sub_statistics:
clipped_statistics[key] = clipped_sub_statistics[key]
metrics = metrics | sub_metrics

if len(client_statistics_keys) > 0:
raise ValueError(
f'Not all client statistics have been clipped. '
f'These keys are missing from mechanisms_and_keys: {client_statistics_keys}.'
)

return clipped_statistics, metrics

def add_noise(
self,
statistics: TrainingStatistics,
cohort_size: int,
name_formatting_fn=lambda n: StringMetricName(n),
seed: Optional[int] = None) -> Tuple[TrainingStatistics, Metrics]:

if not isinstance(statistics, MappedVectorStatistics):
raise TypeError(
'Statistics must be of type MappedVectorStatistics.')

noised_statistics: MappedVectorStatistics = MappedVectorStatistics(
weight=statistics.weight)
metrics = Metrics()
client_statistics_keys = set(statistics.keys())
for mechanism_name, (
mechanism, mechanism_keys) in self.mechanisms_and_keys.items():

def mechanism_name_formatting_fn(n, prefix=mechanism_name):
return name_formatting_fn(f'{prefix} | {n}')

# Extract client statistics keys that match the keys for current mechanism
sub_statistics: MappedVectorStatistics = MappedVectorStatistics()
for key in mechanism_keys:
if key in client_statistics_keys: # exact key name
sub_statistics[key] = statistics[key]
client_statistics_keys.remove(key)
else:
assert key[
-1] == '/', f"{key} does not appear as a key in the client statistics."
for client_key in statistics:
if client_key.startswith(
key): # matches f'{key_prefix}/'
sub_statistics[client_key] = statistics[client_key]
client_statistics_keys.discard(client_key)

# Apply noise using mechanism
noised_sub_statistics, sub_metrics = mechanism.add_noise(
sub_statistics, cohort_size, mechanism_name_formatting_fn,
seed)

# Recombine noised statistics and metrics
for key in noised_sub_statistics:
noised_statistics[key] = noised_sub_statistics[key]
metrics = metrics | sub_metrics

if len(client_statistics_keys) > 0:
raise ValueError(
f'Not all client statistics have been noised. '
f'These keys are missing from mechanisms_and_keys: {client_statistics_keys}.'
)

return noised_statistics, metrics
113 changes: 112 additions & 1 deletion tests/privacy/test_privacy_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from pfl.hyperparam import ModelHyperParams
from pfl.hyperparam.base import AlgorithmHyperParams
from pfl.internal.ops import check_mlx_installed, get_pytorch_major_version, get_tf_major_version
from pfl.metrics import Metrics, get_overall_value
from pfl.metrics import Metrics, Weighted, get_overall_value
from pfl.privacy import compute_parameters
from pfl.privacy.approximate_mechanism import SquaredErrorLocalPrivacyMechanism
from pfl.privacy.gaussian_mechanism import GaussianMechanism
from pfl.privacy.joint_mechanism import JointMechanism
from pfl.privacy.laplace_mechanism import LaplaceMechanism
from pfl.privacy.privacy_mechanism import (
CentrallyAppliedPrivacyMechanism,
Expand Down Expand Up @@ -484,3 +485,113 @@ def test_norm_clipping_only(self, order, clipping_bound, expected_arrays,
is_local_privacy=True)
assert get_overall_value(
metrics[name]) == pytest.approx(expected_clip_fraction)

@pytest.mark.parametrize('ops_module', framework_fixtures)
def test_joint_mechanism(self, ops_module, fix_global_random_seeds):

def get_mock_mechanism_call(mechanism_fn_name):

def mock_mechanism_call(*args):
statistics = args[0]
name_formatting_fn = args[-2]
metrics = Metrics([
(name_formatting_fn(f'{mechanism_fn_name}'),
Weighted.from_unweighted(
sum([x.shape[0] for x in statistics.values()]))),
])
return statistics, metrics

return mock_mechanism_call

first_mechanism_name = 'laplace_mechanism'
laplace_mechanism = MagicMock()
laplace_mechanism.constrain_sensitivity = MagicMock(
side_effect=get_mock_mechanism_call('constrain_sensitivity'))
laplace_mechanism.add_noise = MagicMock(
side_effect=get_mock_mechanism_call('add_noise'))
laplace_keys = ['laplace/', 'laplace_exact']

second_mechanism_name = 'gaussian_mechanism'
gaussian_mechanism = MagicMock()
gaussian_mechanism.constrain_sensitivity = MagicMock(
side_effect=get_mock_mechanism_call('constrain_sensitivity'))
gaussian_mechanism.add_noise = MagicMock(
side_effect=get_mock_mechanism_call('add_noise'))
gaussian_keys = ['gaussian_exact1', 'gaussian_exact2']

mechanisms_and_keys = {
first_mechanism_name: (laplace_mechanism, laplace_keys),
second_mechanism_name: (gaussian_mechanism, gaussian_keys)
}
joint_mechanism = JointMechanism(mechanisms_and_keys)

input_stats_keys = [
'laplace/subpath1', 'laplace/subpath2', 'gaussian_exact1',
'gaussian_exact2', 'laplace_exact'
]
input_stats = MappedVectorStatistics(dict(
zip(input_stats_keys, [
np.ones(i + 1, dtype=np.float32)
for i in range(len(input_stats_keys))
])),
weight=10)
input_tensor_stats = self._to_tensor_stats(input_stats, ops_module)

seed = 0
noised_arrays, metrics = joint_mechanism.postprocess_one_user(
stats=input_tensor_stats, user_context=MagicMock(seed=seed))

# Weight of statistics after is same is before
assert noised_arrays.weight == input_tensor_stats.weight

# check that each mechanism was applied to the correct portions of the user statistics
assert set(laplace_mechanism.constrain_sensitivity.call_args[0]
[0].keys()) == {
'laplace/subpath1', 'laplace/subpath2', 'laplace_exact'
}
assert set(gaussian_mechanism.constrain_sensitivity.call_args[0]
[0].keys()) == {'gaussian_exact1', 'gaussian_exact2'}
assert set(laplace_mechanism.add_noise.call_args[0][0].keys()) == {
'laplace/subpath1', 'laplace/subpath2', 'laplace_exact'
}
assert set(gaussian_mechanism.add_noise.call_args[0][0].keys()) == {
'gaussian_exact1', 'gaussian_exact2'
}

# Laplace should have been applied to shapes 1 + 2 + 5 = 8, and gaussian to 3 + 4 = 7
expected_metrics = Metrics()
for name, expected_val in zip(
[first_mechanism_name, second_mechanism_name], [8, 7]):
for fn_name in ['constrain_sensitivity', 'add_noise']:
expected_metrics[PrivacyMetricName(
f'{name} | {fn_name}',
is_local_privacy=True)] = expected_val

# Check that returned metric names and vals are as expected
for name, expected_metric in expected_metrics:
assert get_overall_value(metrics[name]) == expected_metric

# Should raise Value error when a statistics key is not present in mechanisms_and_keys
mechanisms_and_keys_missing_key = {
first_mechanism_name: (laplace_mechanism, laplace_keys[:-1]),
second_mechanism_name: (gaussian_mechanism, gaussian_keys)
}

joint_mechanism = JointMechanism(mechanisms_and_keys_missing_key)

with pytest.raises(ValueError):
noised_arrays, metrics = joint_mechanism.postprocess_one_user(
stats=input_tensor_stats, user_context=MagicMock(seed=seed))

# Should raise assertion error when an exact key name is provided that is not present in statistics.keys()
mechanisms_and_keys_extra_key = {
first_mechanism_name: (laplace_mechanism, laplace_keys),
second_mechanism_name:
(gaussian_mechanism, [*gaussian_keys, 'extra_key'])
}

joint_mechanism = JointMechanism(mechanisms_and_keys_extra_key)

with pytest.raises(AssertionError):
noised_arrays, metrics = joint_mechanism.postprocess_one_user(
stats=input_tensor_stats, user_context=MagicMock(seed=seed))

0 comments on commit ca8af99

Please sign in to comment.