Weighted sampling of items in Python
Suppose you need to pick items at random, but not uniformly: each item should come up in proportion to a weight you assign it, so that the picks follow a realistic distribution.
For example, suppose you’re writing a US meteor-impact simulation game and want to choose the state the meteor lands in. If the impact site is completely random geographically, a state’s chance of being hit is proportional to its area. The following example therefore weights each state (plus the District of Columbia) by its area in square miles and picks one with statistically accurate odds.
# Each state (plus the District of Columbia) is weighted by its area, so a
# uniformly random impact point lands in a state with probability
# proportional to that area. Figures are presumably total area in square
# miles -- verify against the original data source before relying on them.
states = {
    "Alabama": 52419,
    "Alaska": 663267,
    "Arizona": 113998,
    "Arkansas": 53179,
    "California": 163695,
    "Colorado": 104093,
    "Connecticut": 5544,
    "Delaware": 2489,
    "District of Columbia": 68,
    "Florida": 65755,
    "Georgia": 59425,
    "Hawaii": 10931,
    "Idaho": 83570,
    "Illinois": 57914,
    "Indiana": 36418,
    "Iowa": 56271,
    "Kansas": 82276,
    "Kentucky": 40410,
    "Louisiana": 51840,
    "Maine": 35385,
    "Maryland": 12407,
    "Massachusetts": 10555,
    "Michigan": 96716,
    "Minnesota": 86938,
    "Mississippi": 48431,
    "Missouri": 69704,
    "Montana": 147042,
    "Nebraska": 77353,
    "Nevada": 110560,
    "New Hampshire": 9350,
    "New Jersey": 8722,
    "New Mexico": 121589,
    "New York": 54556,
    "North Carolina": 53818,
    "North Dakota": 70700,
    "Ohio": 44825,
    "Oklahoma": 69899,
    "Oregon": 98380,
    "Pennsylvania": 46056,
    "Rhode Island": 1545,
    "South Carolina": 32020,
    # NOTE(review): identical to Nebraska's figure above -- likely a copy
    # error in the source data; verify.
    "South Dakota": 77353,
    "Tennessee": 42144,
    "Texas": 268580,
    "Utah": 84898,
    "Vermont": 9615,
    "Virginia": 42774,
    "Washington": 71300,
    "West Virginia": 24230,
    "Wisconsin": 65498,
    "Wyoming": 97813,
}

# A single parenthesized argument prints the same under Python 2 and 3.
print(weighted(states))
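Here is the function. It keeps a running total of the weights, draws a random number below the grand total, and uses a binary search (`bisect`) to find the item whose slice of that range contains the number. The `test` helper draws many samples and prints how often each item was picked, scaled back into the units of the weights, so the printed numbers should come out close to the input weights: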
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | import collections import random import bisect def weighted(items): # items looks like: {"A": 1, "B": 2, "C": 10} total = sum(items.values()) weighted = [""] weights = [] current = 0 for item, weight in items.items(): weighted.append(item) weights.append(current) current += weight return weighted[bisect.bisect(weights, random.randint(0, total-1))] def test(items, samples): factor = samples * (1/float(sum(items.values()))) results = collections.Counter([weighted(items) for i in range(samples)]) results = dict((item, count / factor) for item, count in results.items()) import pprint pprint.pprint(dict(results)) if __name__ == "__main__": test({"a": 1, "b": 2, "c": 1}, 100000) |