Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/eight-radios-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Checkpoints`: Add a new checkpoint variant `Checkpoint256` using `uint256` type for the value and key.
203 changes: 203 additions & 0 deletions contracts/utils/structs/Checkpoints.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,209 @@ library Checkpoints {
*/
error CheckpointUnorderedInsertion();

struct Trace256 {
Checkpoint256[] _checkpoints;
}

struct Checkpoint256 {
uint256 _key;
uint256 _value;
}

/**
* @dev Pushes a (`key`, `value`) pair into a Trace256 so that it is stored as the checkpoint.
*
* Returns previous value and new value.
*
* IMPORTANT: Never accept `key` as a user input, since an arbitrary `type(uint256).max` key set will disable the
* library.
*/
function push(
Trace256 storage self,
uint256 key,
uint256 value
) internal returns (uint256 oldValue, uint256 newValue) {
return _insert(self._checkpoints, key, value);
}

/**
* @dev Returns the value in the first (oldest) checkpoint with key greater or equal than the search key, or zero if
* there is none.
*/
function lowerLookup(Trace256 storage self, uint256 key) internal view returns (uint256) {
uint256 len = self._checkpoints.length;
uint256 pos = _lowerBinaryLookup(self._checkpoints, key, 0, len);
return pos == len ? 0 : _unsafeAccess(self._checkpoints, pos)._value;
}

/**
* @dev Returns the value in the last (most recent) checkpoint with key lower or equal than the search key, or zero
* if there is none.
*/
function upperLookup(Trace256 storage self, uint256 key) internal view returns (uint256) {
uint256 len = self._checkpoints.length;
uint256 pos = _upperBinaryLookup(self._checkpoints, key, 0, len);
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the last (most recent) checkpoint with key lower or equal than the search key, or zero
* if there is none.
*
* NOTE: This is a variant of {upperLookup} that is optimized to find "recent" checkpoint (checkpoints with high
* keys).
*/
function upperLookupRecent(Trace256 storage self, uint256 key) internal view returns (uint256) {
uint256 len = self._checkpoints.length;

uint256 low = 0;
uint256 high = len;

if (len > 5) {
uint256 mid = len - Math.sqrt(len);
if (key < _unsafeAccess(self._checkpoints, mid)._key) {
high = mid;
} else {
low = mid + 1;
}
}

uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the most recent checkpoint, or zero if there are no checkpoints.
*/
function latest(Trace256 storage self) internal view returns (uint256) {
uint256 pos = self._checkpoints.length;
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns whether there is a checkpoint in the structure (i.e. it is not empty), and if so the key and value
* in the most recent checkpoint.
*/
function latestCheckpoint(Trace256 storage self) internal view returns (bool exists, uint256 _key, uint256 _value) {
uint256 pos = self._checkpoints.length;
if (pos == 0) {
return (false, 0, 0);
} else {
Checkpoint256 storage ckpt = _unsafeAccess(self._checkpoints, pos - 1);
return (true, ckpt._key, ckpt._value);
}
}

/**
* @dev Returns the number of checkpoints.
*/
function length(Trace256 storage self) internal view returns (uint256) {
return self._checkpoints.length;
}

/**
* @dev Returns checkpoint at given position.
*/
function at(Trace256 storage self, uint32 pos) internal view returns (Checkpoint256 memory) {
return self._checkpoints[pos];
}

/**
* @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one.
*/
function _insert(
Checkpoint256[] storage self,
uint256 key,
uint256 value
) private returns (uint256 oldValue, uint256 newValue) {
uint256 pos = self.length;

if (pos > 0) {
Checkpoint256 storage last = _unsafeAccess(self, pos - 1);
uint256 lastKey = last._key;
uint256 lastValue = last._value;

// Checkpoint keys must be non-decreasing.
if (lastKey > key) {
revert CheckpointUnorderedInsertion();
}

// Update or push new checkpoint
if (lastKey == key) {
last._value = value;
} else {
self.push(Checkpoint256({_key: key, _value: value}));
}
return (lastValue, value);
} else {
self.push(Checkpoint256({_key: key, _value: value}));
return (0, value);
}
}

/**
* @dev Return the index of the first (oldest) checkpoint with key strictly bigger than the search key, or `high`
* if there is none. `low` and `high` define a section where to do the search, with inclusive `low` and exclusive
* `high`.
*
* WARNING: `high` should not be greater than the array's length.
*/
function _upperBinaryLookup(
Checkpoint256[] storage self,
uint256 key,
uint256 low,
uint256 high
) private view returns (uint256) {
while (low < high) {
uint256 mid = Math.average(low, high);
if (_unsafeAccess(self, mid)._key > key) {
high = mid;
} else {
low = mid + 1;
}
}
return high;
}

/**
* @dev Return the index of the first (oldest) checkpoint with key greater or equal than the search key, or `high`
* if there is none. `low` and `high` define a section where to do the search, with inclusive `low` and exclusive
* `high`.
*
* WARNING: `high` should not be greater than the array's length.
*/
function _lowerBinaryLookup(
Checkpoint256[] storage self,
uint256 key,
uint256 low,
uint256 high
) private view returns (uint256) {
while (low < high) {
uint256 mid = Math.average(low, high);
if (_unsafeAccess(self, mid)._key < key) {
low = mid + 1;
} else {
high = mid;
}
}
return high;
}

/**
* @dev Access an element of the array without performing bounds check. The position is assumed to be within bounds.
*/
function _unsafeAccess(
Checkpoint256[] storage self,
uint256 pos
) private pure returns (Checkpoint256 storage result) {
assembly {
mstore(0, self.slot)
result.slot := add(keccak256(0, 0x20), mul(pos, 2))
}
}

struct Trace224 {
Checkpoint224[] _checkpoints;
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate/templates/Checkpoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ function _unsafeAccess(
) private pure returns (${opts.checkpointTypeName} storage result) {
assembly {
mstore(0, self.slot)
result.slot := add(keccak256(0, 0x20), pos)
result.slot := add(keccak256(0, 0x20), ${opts.checkpointSize === 1 ? 'pos' : `mul(pos, ${opts.checkpointSize})`})
}
}
`;
Expand Down
5 changes: 3 additions & 2 deletions scripts/generate/templates/Checkpoints.opts.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// OPTIONS
const VALUE_SIZES = [224, 208, 160];
const VALUE_SIZES = [256, 224, 208, 160];

const defaultOpts = size => ({
historyTypeName: `Trace${size}`,
checkpointTypeName: `Checkpoint${size}`,
checkpointFieldName: '_checkpoints',
keyTypeName: `uint${256 - size}`,
checkpointSize: size < 256 ? 1 : 2,
keyTypeName: size < 256 ? `uint${256 - size}` : 'uint256',
keyFieldName: '_key',
valueTypeName: `uint${size}`,
valueFieldName: '_value',
Expand Down
6 changes: 5 additions & 1 deletion scripts/generate/templates/Checkpoints.t.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ Checkpoints.${opts.historyTypeName} internal _ckpts;
function _bound${capitalize(opts.keyTypeName)}(${opts.keyTypeName} x, ${opts.keyTypeName} min, ${
opts.keyTypeName
} max) internal pure returns (${opts.keyTypeName}) {
return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max)));
return ${
opts.keyTypeName === 'uint256'
? 'bound(x, min, max)'
: `SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max)))`
};
}

function _prepareKeys(${opts.keyTypeName}[] memory keys, ${opts.keyTypeName} maxSpread) internal pure {
Expand Down
108 changes: 108 additions & 0 deletions test/utils/structs/Checkpoints.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,114 @@ import {Test} from "forge-std/Test.sol";
import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol";
import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol";

contract CheckpointsTrace256Test is Test {
using Checkpoints for Checkpoints.Trace256;

// Maximum gap between keys used during the fuzzing tests: the `_prepareKeys` function will make sure that
// key#n+1 is in the [key#n, key#n + _KEY_MAX_GAP] range.
uint8 internal constant _KEY_MAX_GAP = 64;

Checkpoints.Trace256 internal _ckpts;

// helpers
function _boundUint256(uint256 x, uint256 min, uint256 max) internal pure returns (uint256) {
return bound(x, min, max);
}

function _prepareKeys(uint256[] memory keys, uint256 maxSpread) internal pure {
uint256 lastKey = 0;
for (uint256 i = 0; i < keys.length; ++i) {
uint256 key = _boundUint256(keys[i], lastKey, lastKey + maxSpread);
keys[i] = key;
lastKey = key;
}
}

function _assertLatestCheckpoint(bool exist, uint256 key, uint256 value) internal view {
(bool _exist, uint256 _key, uint256 _value) = _ckpts.latestCheckpoint();
assertEq(_exist, exist);
assertEq(_key, key);
assertEq(_value, value);
}

// tests
function testPush(uint256[] memory keys, uint256[] memory values, uint256 pastKey) public {
vm.assume(values.length > 0 && values.length <= keys.length);
_prepareKeys(keys, _KEY_MAX_GAP);

// initial state
assertEq(_ckpts.length(), 0);
assertEq(_ckpts.latest(), 0);
_assertLatestCheckpoint(false, 0, 0);

uint256 duplicates = 0;
for (uint256 i = 0; i < keys.length; ++i) {
uint256 key = keys[i];
uint256 value = values[i % values.length];
if (i > 0 && key == keys[i - 1]) ++duplicates;

// push
_ckpts.push(key, value);

// check length & latest
assertEq(_ckpts.length(), i + 1 - duplicates);
assertEq(_ckpts.latest(), value);
_assertLatestCheckpoint(true, key, value);
}

if (keys.length > 0) {
uint256 lastKey = keys[keys.length - 1];
if (lastKey > 0) {
pastKey = _boundUint256(pastKey, 0, lastKey - 1);

vm.expectRevert();
this.push(pastKey, values[keys.length % values.length]);
}
}
}

// used to test reverts
function push(uint256 key, uint256 value) external {
_ckpts.push(key, value);
}

function testLookup(uint256[] memory keys, uint256[] memory values, uint256 lookup) public {
vm.assume(values.length > 0 && values.length <= keys.length);
_prepareKeys(keys, _KEY_MAX_GAP);

uint256 lastKey = keys.length == 0 ? 0 : keys[keys.length - 1];
lookup = _boundUint256(lookup, 0, lastKey + _KEY_MAX_GAP);

uint256 upper = 0;
uint256 lower = 0;
uint256 lowerKey = type(uint256).max;
for (uint256 i = 0; i < keys.length; ++i) {
uint256 key = keys[i];
uint256 value = values[i % values.length];

// push
_ckpts.push(key, value);

// track expected result of lookups
if (key <= lookup) {
upper = value;
}
// find the first key that is not smaller than the lookup key
if (key >= lookup && (i == 0 || keys[i - 1] < lookup)) {
lowerKey = key;
}
if (key == lowerKey) {
lower = value;
}
}

// check lookup
assertEq(_ckpts.lowerLookup(lookup), lower);
assertEq(_ckpts.upperLookup(lookup), upper);
assertEq(_ckpts.upperLookupRecent(lookup), upper);
}
}

contract CheckpointsTrace224Test is Test {
using Checkpoints for Checkpoints.Trace224;

Expand Down
Loading