Skip to content

Commit eb93e56

Browse files
committed
Add Sequence.randomElement()
The stdlib's randomElement() method extends Collection, since it operates by selecting a random index. This adds randomElement() as a Sequence method, specializing the reservoir sampling algorithm from randomSample(count:) for a count of 1.
1 parent 11fda51 commit eb93e56

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

Sources/Algorithms/RandomSample.swift

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,61 @@ extension Sequence {
235235
return randomSample(count: k, using: &g)
236236
}
237237
}
238+
239+
//===----------------------------------------------------------------------===//
240+
// randomElement()
241+
//===----------------------------------------------------------------------===//
242+
243+
// This method is a single-element specialization of `randomSample(count:)`,
244+
// and extends the stdlib's `randomElement()` functionality (which is available
245+
// for collections) down to sequences.
246+
247+
extension Sequence {
248+
/// Randomly selects an element from this sequence, using the given generator
249+
/// as the source of randomness.
250+
///
251+
/// - Parameter rng: The random number generator to use for sampling.
252+
/// - Returns: A random element. If the sequence has no elements, this method
253+
/// returns `nil`.
254+
///
255+
/// - Complexity: O(*n*), where *n* is the length of the sequence.
256+
@inlinable
257+
public func randomElement<G: RandomNumberGenerator>(
258+
using rng: inout G
259+
) -> Element? {
260+
var iterator = makeIterator()
261+
guard var result = iterator.next() else {
262+
return nil
263+
}
264+
265+
var w = 1.0
266+
while true {
267+
w *= nextW(k: 1, using: &rng)
268+
var offset = nextOffset(w: w, using: &rng)
269+
while offset > 0, let _ = iterator.next() {
270+
offset -= 1
271+
}
272+
guard let nextElement = iterator.next()
273+
else { break }
274+
275+
result = nextElement
276+
}
277+
278+
return result
279+
}
280+
281+
/// Randomly selects an element from this sequence.
282+
///
283+
/// This method is equivalent to calling `randomSample(using:)`, passing in
284+
/// the system's default random generator.
285+
///
286+
/// - Returns: A random element. If the sequence has no elements, this method
287+
/// returns `nil`.
288+
///
289+
/// - Complexity: O(*n*), where *n* is the length of the sequence.
290+
@inlinable
291+
public func randomElement() -> Element? {
292+
var g = SystemRandomNumberGenerator()
293+
return randomElement(using: &g)
294+
}
295+
}

Tests/SwiftAlgorithmsTests/RandomSampleTests.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,4 +132,37 @@ final class RandomSampleTests: XCTestCase {
132132
almostAllZero = AlmostAllZeroGenerator(seed: 0)
133133
_ = c.randomSample(count: k, using: &almostAllZero) // must not crash
134134
}
135+
136+
func testSequenceRandomElement() {
137+
XCTAssertNil(emptySequence.randomElement())
138+
139+
let expectedRange = expectedRange(for: iterations)
140+
let randomElements = (0..<(n*iterations)).map { _ in
141+
s.randomElement()!
142+
}.frequencies
143+
XCTAssertEqual(randomElements.count, n)
144+
XCTAssert(randomElements.values.allSatisfy { expectedRange.contains($0) })
145+
146+
let oneElementSequence = sequence(first: 0, next: { _ in nil })
147+
let randomOneElement = (0..<iterations).map { _ in
148+
oneElementSequence.randomElement()!
149+
}
150+
XCTAssert(randomOneElement.allSatisfy { $0 == 0 })
151+
152+
let twoElementSequence = sequence(state: [1, 2].makeIterator(), next: { $0.next() })
153+
let randomTwoElement = (0..<(2*iterations)).map { _ in
154+
twoElementSequence.randomElement()!
155+
}.frequencies
156+
XCTAssertEqual(randomTwoElement.count, 2)
157+
XCTAssert(randomElements.values.allSatisfy { expectedRange.contains($0) })
158+
}
159+
160+
func testSequenceRandomElementRepeatable() {
161+
let seed = UInt64.random(in: 0 ... .max)
162+
var generator = SplitMix64(seed: seed)
163+
let elements1 = (0..<500).map { _ in s.randomElement(using: &generator)! }
164+
generator = SplitMix64(seed: seed)
165+
let elements2 = (0..<500).map { _ in s.randomElement(using: &generator)! }
166+
XCTAssertEqual(elements1, elements2)
167+
}
135168
}

0 commit comments

Comments
 (0)