Skip to content

Commit 56dbd20

Browse files
authored
fix(cheatcodes): overflow in randomNumber w/range (#8361)
1 parent eff3f43 commit 56dbd20

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

crates/cheatcodes/src/utils.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,13 @@ impl Cheatcode for randomUint_1Call {
162162
ensure!(min <= max, "min must be less than or equal to max");
163163
// Generate random between range min..=max
164164
let mut rng = rand::thread_rng();
165-
let range = max - min + U256::from(1);
166-
let random_number = rng.gen::<U256>() % range + min;
165+
let exclusive_modulo = max - min;
166+
let mut random_number = rng.gen::<U256>();
167+
if exclusive_modulo != U256::MAX {
168+
let inclusive_modulo = exclusive_modulo + U256::from(1);
169+
random_number %= inclusive_modulo;
170+
}
171+
random_number += min;
167172
Ok(random_number.abi_encode())
168173
}
169174
}

testdata/default/cheats/RandomUint.t.sol

+12-12
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,27 @@ import "cheats/Vm.sol";
77
contract RandomUint is DSTest {
88
Vm constant vm = Vm(HEVM_ADDRESS);
99

10-
// All tests use `>=` and `<=` to verify that ranges are inclusive and that
11-
// a value of zero may be generated.
1210
function testRandomUint() public {
13-
uint256 rand = vm.randomUint();
14-
assertTrue(rand >= 0);
11+
vm.randomUint();
1512
}
1613

17-
function testRandomUint(uint256 min, uint256 max) public {
18-
vm.assume(max >= min);
19-
uint256 rand = vm.randomUint(min, max);
20-
assertTrue(rand >= min, "rand >= min");
21-
assertTrue(rand <= max, "rand <= max");
14+
function testRandomUintRangeOverflow() public {
15+
vm.randomUint(0, uint256(int256(-1)));
2216
}
2317

24-
function testRandomUint(uint256 val) public {
18+
function testRandomUintSame(uint256 val) public {
2519
uint256 rand = vm.randomUint(val, val);
2620
assertTrue(rand == val);
2721
}
2822

23+
function testRandomUintRange(uint256 min, uint256 max) public {
24+
vm.assume(max >= min);
25+
uint256 rand = vm.randomUint(min, max);
26+
assertTrue(rand >= min, "rand >= min");
27+
assertTrue(rand <= max, "rand <= max");
28+
}
29+
2930
function testRandomAddress() public {
30-
address rand = vm.randomAddress();
31-
assertTrue(rand >= address(0));
31+
vm.randomAddress();
3232
}
3333
}

0 commit comments

Comments
 (0)