Add PRVAccountant #493
Closed
16 commits
a53c353
Initial implementation of PRV Accountant
tcbegley 0dd2788
Fix linting issues
tcbegley a8ee88a
Add docstrings and warning
tcbegley 2663fb0
Add PRVAcountant regression tests
tcbegley 7a1f42d
Remove spurious import added by VSCode
tcbegley 9456bb8
Sort imports
tcbegley 334b5a8
Add PRVAccountant test
bb2f182
Fix values now that accountant's defualts have changed
35e01af
Remove warning
d011f40
Fix lint and assert_allclose
e92938d
Disable coveralls. Make prv accountant test not nightly. Make prv acc…
d569416
Couple of fixes
9c67382
Remove long tests for PRV (keeping the ones that run on time). Updati…
b7a4b93
Change accountant
4632c73
Undo deterministic for test
4fbc963
Removing deterministic for dp_layers/dp_rnn_test.py
@@ -0,0 +1,19 @@

```python
from .compose import compose_heterogeneous
from .domain import Domain, compute_safe_domain_size
from .prvs import (
    DiscretePRV,
    PoissonSubsampledGaussianPRV,
    TruncatedPrivacyRandomVariable,
    discretize,
)


__all__ = [
    "DiscretePRV",
    "Domain",
    "PoissonSubsampledGaussianPRV",
    "TruncatedPrivacyRandomVariable",
    "compose_heterogeneous",
    "compute_safe_domain_size",
    "discretize",
]
```
@@ -0,0 +1,62 @@

```python
from typing import List

import numpy as np
from scipy.fft import irfft, rfft
from scipy.signal import convolve

from .prvs import DiscretePRV


def _compose_fourier(dprv: DiscretePRV, num_self_composition: int) -> DiscretePRV:
    if len(dprv) % 2 != 0:
        raise ValueError("Can only compose evenly sized discrete PRVs")

    composed_pmf = irfft(rfft(dprv.pmf) ** num_self_composition)

    m = num_self_composition - 1
    if num_self_composition % 2 == 0:
        m += len(composed_pmf) // 2
    composed_pmf = np.roll(composed_pmf, m)

    domain = dprv.domain.shift_right(dprv.domain.shifts * (num_self_composition - 1))

    return DiscretePRV(pmf=composed_pmf, domain=domain)


def _compose_two(dprv_left: DiscretePRV, dprv_right: DiscretePRV) -> DiscretePRV:
    pmf = convolve(dprv_left.pmf, dprv_right.pmf, mode="same")
    domain = dprv_left.domain.shift_right(dprv_right.domain.shifts)
    return DiscretePRV(pmf=pmf, domain=domain)


def _compose_convolution_tree(dprvs: List[DiscretePRV]) -> DiscretePRV:
    # repeatedly convolve neighbouring PRVs until we only have one left
    while len(dprvs) > 1:
        dprvs_conv = []
        if len(dprvs) % 2 == 1:
            dprvs_conv.append(dprvs.pop())

        for dprv_left, dprv_right in zip(dprvs[:-1:2], dprvs[1::2]):
            dprvs_conv.append(_compose_two(dprv_left, dprv_right))

        dprvs = dprvs_conv
    return dprvs[0]


def compose_heterogeneous(
    dprvs: List[DiscretePRV], num_self_compositions: List[int]
) -> DiscretePRV:
    r"""
    Compose a heterogenous list of PRVs with multiplicity. We use FFT to compose
    identical PRVs with themselves first, then pairwise convolve the remaining PRVs.

    This is the approach taken in https://github.com/microsoft/prv_accountant
    """
    if len(dprvs) != len(num_self_compositions):
        raise ValueError("dprvs and num_self_compositions must have the same length")

    dprvs = [
        _compose_fourier(dprv, num_self_composition)
        for dprv, num_self_composition in zip(dprvs, num_self_compositions)
    ]
    return _compose_convolution_tree(dprvs)
```
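The FFT step in `_compose_fourier` relies on the convolution theorem: raising the real FFT of a PMF to the k-th power and inverting it gives the k-fold circular self-convolution. The following is a minimal sketch of that identity only, using `numpy.fft` instead of `scipy.fft`. The helper names and the example PMF are hypothetical, and the sketch leaves out the `np.roll` re-centering and the domain shift that the PR applies afterwards.

```python
import numpy as np


def circular_self_convolve_fft(pmf: np.ndarray, k: int) -> np.ndarray:
    """k-fold circular self-convolution via the convolution theorem (illustrative helper)."""
    # For an even-length input, irfft returns an array of the same length.
    return np.fft.irfft(np.fft.rfft(pmf) ** k, n=len(pmf))


def circular_self_convolve_direct(pmf: np.ndarray, k: int) -> np.ndarray:
    """Reference implementation: convolve k - 1 times, wrapping overflow modulo n."""
    n = len(pmf)
    out = pmf.copy()
    for _ in range(k - 1):
        full = np.convolve(out, pmf)  # linear convolution, length 2n - 1
        out = full[:n].copy()
        out[: n - 1] += full[n:]  # fold indices n..2n-2 back onto 0..n-2
    return out


# Hypothetical even-length PMF, chosen only for illustration.
pmf = np.array([0.1, 0.2, 0.4, 0.2, 0.05, 0.05])
fft_result = circular_self_convolve_fft(pmf, 3)
direct_result = circular_self_convolve_direct(pmf, 3)
print(np.allclose(fft_result, direct_result))
```

Because the convolution is circular, probability mass that would leave the grid wraps around instead, which is presumably why the domain has to be large enough before composing.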
@@ -0,0 +1,99 @@

```python
from dataclasses import dataclass
from typing import Sequence

import numpy as np

from ...rdp import RDPAccountant


@dataclass
class Domain:
    r"""
    Stores relevant information about the domain on which PRVs are discretized, and
    includes a few convenience methods for manipulating it.
    """
    t_min: float
    t_max: float
    size: int
    shifts: float = 0.0

    def __post_init__(self):
        if not isinstance(self.size, int):
            raise TypeError("`size` must be an integer")
        if self.size % 2 != 0:
            raise ValueError("`size` must be even")

    @classmethod
    def create_aligned(cls, t_min: float, t_max: float, dt: float) -> "Domain":
        t_min = np.floor(t_min / dt) * dt
        t_max = np.ceil(t_max / dt) * dt

        size = int(np.round((t_max - t_min) / dt)) + 1

        if size % 2 == 1:
            size += 1
            t_max += dt

        domain = cls(t_min, t_max, size)

        if np.abs(domain.dt - dt) / dt >= 1e-8:
            raise RuntimeError

        return domain

    def shift_right(self, dt: float) -> "Domain":
        return Domain(
            t_min=self.t_min + dt,
            t_max=self.t_max + dt,
            size=self.size,
            shifts=self.shifts + dt,
        )

    @property
    def dt(self):
        return (self.t_max - self.t_min) / (self.size - 1)

    @property
    def ts(self):
        return np.linspace(self.t_min, self.t_max, self.size)

    def __getitem__(self, i: int) -> float:
        return self.t_min + i * self.dt


def compute_safe_domain_size(
    prvs,
    max_self_compositions: Sequence[int],
    eps_error: float,
    delta_error: float,
) -> float:
    """
    Compute safe domain size for the discretization of the PRVs.

    For details about this algorithm, see remark 5.6 in
    https://www.microsoft.com/en-us/research/publication/numerical-composition-of-differential-privacy/
    """
    total_compositions = sum(max_self_compositions)

    rdp_accountant = RDPAccountant()
    for prv, max_self_composition in zip(prvs, max_self_compositions):
        rdp_accountant.history.append(
            (prv.noise_multiplier, prv.sample_rate, max_self_composition)
        )

    L_max = rdp_accountant.get_epsilon(delta_error / 4)

    for prv, max_self_composition in zip(prvs, max_self_compositions):
        rdp_accountant = RDPAccountant()
        rdp_accountant.history = [(prv.noise_multiplier, prv.sample_rate, 1)]
        L_max = max(
            L_max,
            rdp_accountant.get_epsilon(delta=delta_error / (8 * total_compositions)),
        )

    # FIXME: this implementation is adapted from the code accompanying the paper, but
    # disagrees subtly with the theory from remark 5.6. It's not immediately clear this
    # gives the right guarantees in all cases, though it's fine for eps_error < 1 and
    # hence generic cases.
    # cf. https://github.com/microsoft/prv_accountant/discussions/34
    return max(L_max, eps_error) + 3
```
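As a rough illustration of the grid arithmetic in `Domain.create_aligned`, the sketch below repeats the same steps in a standalone function: snap both endpoints outwards to multiples of `dt`, count the grid points, and add one more point on the right if the count is odd. The function name `aligned_grid` and the input values are assumptions made for this example and are not part of the PR.

```python
import numpy as np


def aligned_grid(t_min: float, t_max: float, dt: float):
    """Mirror the arithmetic of `Domain.create_aligned` (hypothetical standalone helper)."""
    t_min = np.floor(t_min / dt) * dt  # snap the left endpoint down to a multiple of dt
    t_max = np.ceil(t_max / dt) * dt  # snap the right endpoint up to a multiple of dt
    size = int(np.round((t_max - t_min) / dt)) + 1  # number of grid points
    if size % 2 == 1:
        # Domain requires an even size, so extend the grid by one point on the right.
        size += 1
        t_max += dt
    return float(t_min), float(t_max), size


# Example values chosen so that every quantity is exactly representable in binary.
t_min, t_max, size = aligned_grid(-1.03, 1.03, 0.25)
spacing = (t_max - t_min) / (size - 1)
print(t_min, t_max, size, spacing)
```

With these inputs the endpoints snap to -1.25 and 1.25, which gives 11 points. The odd count is bumped to 12 and the right endpoint moves to 1.5, so the spacing stays at 0.25 as requested.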
I'm in two minds about whether to keep `compute_safe_domain_size` as it is, following the implementation from microsoft/prv_accountant, or to follow the paper more precisely. It's not immediately clear to me that this implementation always satisfies the conditions required in the paper without making an assumption like `eps_error < 1`. The changes required are minor and would look something like the following: