def weightedSelection[T](
    items: Seq[WeightedItem[T]], numSelections: Int, r: Random): Seq[T] = ???
Your task is to implement weightedSelection, such that it makes
'numSelections' random selections (with replacement) from 'items',
where the probability of selecting each item is proportional to its
weighting.
It is a random process, so I'd expect the correctness of your
algorithm can only be verified stochastically, e.g. via repeated
testing using ScalaCheck. For a generous serving of bonus marks,
implement a unit test that provides some degree of confidence that
your function works correctly.
Example, contrived such that the selections occur in exactly the
expected proportions:

val items = Seq(
  WeightedItem("Red", 1d/6),
  WeightedItem("Blue", 2d/6),
  WeightedItem("Green", 3d/6))
val selections = weightedSelection(items, 6, new Random())
//Seq(Green, Red, Green, Green, Red, Blue)
I was reminded of the problem in Unit 3 of Sebastian Thrun's online
course "Programming a Robotic Car" (1), implementing Particle Filters.
But the problem comes up in a lot of places, and I realize that I've
encountered the same basic problem multiple times before:
- Genetic algorithms: selecting next generation in proportion to fitness
- ScalaCheck test data generation: generating a weighted distribution
of different values
- Game AI: Selecting different actions weighted by a suitability metric
Finally, I'll mention that there are methods of solving this problem
considerably more efficient than my first intuitive solution sketch
was - if you're interested, google "roulette selection" (but be
prepared to pick through some noise), or see the course videos.
Discuss, but hold off posting any solution code til at least next Wednesday..?
-Ben
1: http://www.udacity.com/overview/Course/cs373/CourseRev/feb2012
--
You received this message because you are subscribed to the Google Groups "Melbourne Scala User Group" group.
I wrote this solution before looking at any others:
https://gist.github.com/2033639
It was a fairly quick implementation, but I didn't have time to refine
it. Looks like I went for the same concept as Jem: spreading the
options out to occupy space in a range proportional to their weight,
then normalising the random number to that range to discover which
bucket it lands in.
The ugliness in mine comes from the pick_one() method; I'm sure it
should be using one of the collections methods with an accumulator to
achieve the same result, somehow.
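Something like this, perhaps (a sketch only, not the code from my gist;
pickOne and the item names are made up): scanLeft builds the cumulative
weights, and a single draw picks the bucket.

```scala
import scala.util.Random

case class WeightedItem[T](item: T, weight: Double)

// Sketch: cumulative weights via scanLeft, then a linear search for the
// bucket. Weights only need to be positive; the draw is scaled by their
// total, so they need not sum to 1.
def pickOne[T](items: Seq[WeightedItem[T]], r: Random): T = {
  // scanLeft(0.0)(_ + _.weight) yields 0.0, w1, w1+w2, ...; drop the 0.0
  val cumulative = items.scanLeft(0.0)(_ + _.weight).tail
  val d = r.nextDouble() * cumulative.last // one draw, outside the predicate
  items(cumulative.indexWhere(_ >= d)).item
}
```

Note the random number is drawn once, outside the predicate, which
matters as discussed further down the thread.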
-Toby
Actually, I've been in Thoughtworks office in Xi'an, China for a work
trip since Sunday, and it's been absolutely full on: getting used to
life (or just, "survival") in a big city here, and meeting the local
teams, and I haven't had a chance to complete my roulette selection
method.
I'll post it later next week when I get back.
-Ben
Here's my attempt without looking at the others'. I'll include my
thinking process.
First, my initial naive attempt:
// 1st cut: simple but wrong
// (assumes case class WeightedItem[T](item: T, weight: Double)
//  from the problem statement)
import scala.util.Random

object weightedSelection {
  def apply[T](
      items: Seq[WeightedItem[T]],
      numSelections: Int,
      r: Random): Seq[T] = {
    val sortedItems = items.sortBy(_.weight)
    for (i <- 1 to numSelections) yield
      sortedItems.find(_.weight >= r.nextDouble).get.item
  }

  def count[T](items: Seq[T]) {
    val grouped = items.groupBy(_.toString)
    for ((key, list) <- grouped)
      println(key + ": " + (list.length / (items.size * 1.0)))
  }

  def test(numToSelect: Int) {
    val weights = Seq(0.50, 0.30, 0.15, 0.05)
    val items = for (w <- weights) yield WeightedItem(w.toString, w)
    count(apply(items, numToSelect, new scala.util.Random))
  }
}
Note test(...) just executes a test run and prints the results to
screen. It's not a unit test.
I saw a couple of niggling issues straight after writing this:
1) The probability of selection is just plain wrong ... the items
aren't spaced out correctly
2) It doesn't account for duplicate probabilities, e.g. 2 items both at 50%
3) It doesn't account for the gap between the maximum probability
supplied and 1 (100%)
Issue #2 actually helped me figure out how to fix it, so my 2nd attempt:
// 2nd cut
object weightedSelection {
  def apply[T](
      items: Seq[WeightedItem[T]],
      numSelections: Int,
      r: Random): Seq[T] = {
    var max = 0.0
    val spacedItems = for (i <- items.sortBy(_.weight)) yield {
      max += i.weight
      WeightedItem(i.item, max)
    }
    for (i <- 1 to numSelections) yield
      spacedItems.find(_.weight >= r.nextDouble).get.item
  }
  // ... the rest are the same
}
This works assuming 0 <= weight <= 1 for each item, and that the sum
of all item weights = 1. If not, then we'll also need to divide the
weight of each spacedItem by the sum of all weights (max at the end of
the 1st for-loop).
The only niggling issue remaining - or at least annoying to me - is
the linear find(...); I test with 1000+ selections, and that's a lot
of linear searches when the items list is large. Will need to dig
through the API for a better data structure.
Incidentally, I noticed the default PRNG is a little biased: the
probability of getting 0.2 < n < 0.5 and 0.5 < n < 1 is almost the
same, as seen below. This also means it's a little hard to write unit
tests for this, as the tolerance to use will depend on the PRNG.
Comments welcome!
cheers,
King
scala> weightedSelection.test(1000000)
0.3: 0.379311
0.15: 0.189952
0.5: 0.380485
0.05: 0.050252
scala> weightedSelection.test(1000000)
0.3: 0.380466
0.15: 0.189491
0.5: 0.380332
0.05: 0.049711
scala> weightedSelection.test(1000000)
0.3: 0.378862
0.15: 0.190337
0.5: 0.380949
0.05: 0.049852
scala> weightedSelection.test(1000000)
0.3: 0.379963
0.15: 0.189682
0.5: 0.380479
0.05: 0.049876
scala> weightedSelection.test(1000000)
0.15: 0.189071
0.3: 0.381133
0.5: 0.379775
0.05: 0.050021
scala> weightedSelection.test(1000000)
0.15: 0.189668
0.3: 0.379144
0.5: 0.380937
0.05: 0.050251
scala> weightedSelection.test(10000000)
0.3: 0.3800097
0.15: 0.189868
0.5: 0.3801614
0.05: 0.0499609
scala> weightedSelection.test(10000000)
0.3: 0.3798768
0.15: 0.1900881
0.5: 0.3799624
0.05: 0.0500727
scala> weightedSelection.test(10000000)
0.3: 0.3801548
0.15: 0.1899825
0.5: 0.3798037
0.05: 0.050059
On 9 March 2012 01:02, Ben Hutchison <brhut...@gmail.com> wrote:
Thanks Andrew, you're right: r.nextDouble gets called more than once
per find(...), although the bug is a little unintuitive to me ...
> Otherwise the PRNG will run again for each comparison, leading to
I can see it doing that after putting in some println() at strategic
places ... but why does it do that? And it runs more than just 1 extra
time per comparison: it varies, but is in general 3 times in my code.
I thought "_.weight >= r.nextDouble" would just get wrapped into a
function value that gets called once per find(...) ...
> biases in the favour of mid range probabilities.
> Better yet (in case weights do not sum to exactly 1):
> { val d = r.nextDouble()*max; spacedItems.find(_.weight >=
> d).get.item }
Yes, this fixes the bias I saw, though I don't get it: why do the
superfluous calls in the "buggy" code lead to the bias?
cheers,
King
>
>
>> The only niggling issue remaining - or at least annoying to me - is
>> the linear find(...); I test with 1000+ selections, and that's a lot
>> of linear searches when the items list is large. Will need to dig
>> through the API for a better data structure.
>
> You could use an array/IndexedSeq and a binary search for O(log L) and
> have exact results, or use John's approximate but O(1) solution.
>
> Regards,
>
> Andrew.
>
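For reference, a sketch of the binary-search variant Andrew suggests
(not code from the thread; it assumes the cumulative weights are
precomputed into an IndexedSeq, and uses scala.collection.Searching):

```scala
import scala.util.Random
import scala.collection.Searching._

case class WeightedItem[T](item: T, weight: Double)

// Precompute cumulative weights once; each selection is then an
// O(log L) binary search instead of a linear find.
def weightedSelection[T](items: IndexedSeq[WeightedItem[T]],
                         numSelections: Int, r: Random): Seq[T] = {
  val cumulative = items.scanLeft(0.0)(_ + _.weight).tail
  val total = cumulative.last
  Seq.fill(numSelections) {
    val d = r.nextDouble() * total // one draw per selection
    val idx = cumulative.search(d) match {
      case Found(i)          => i // d landed exactly on a bucket boundary
      case InsertionPoint(i) => i // first cumulative weight greater than d
    }
    items(idx).item
  }
}
```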
Thanks Andrew, I get it now! (though I didn't get it on the first
several readings until I tried it on a simpler example)*
I had thought f only gets invoked _once_ each time find(...) is
called; I now realise that since f is the predicate, it gets called as
_many_ times as required per find(...) call until a matching item
is found!
i.e. each find(...) results in 1+ calls to f(...)
* var count = 0
  def findP(lim: Int)(num: Int): Boolean = {
    count += 1; println(count); count > lim
  }
  val a = (1 to 10).toArray
  for (i <- 1 to 5) { count = 0; println(a.find(findP(3))) }
> Scala differs from most (non-functional) languages in that functions
> are easily defined inline, so the common rule "the arguments to a
> function get evaluated before the function is called" is no longer
> true and can be misleading. An even less obvious example is
> Array.fill(5)(r.nextDouble()) which will produce an array with 5
> different values.
Yes, I get all these, which is why I was confused ... until I realised
I'd forgotten about the 1+ predicate calls in each find call (I'd been
thinking it was 1-to-1).
> Pure functional languages have referential transparency, which means
> that the number of times it gets executed is irrelevant, so they don't
> have this source of confusion either.
>
> Scala is amazingly flexible and concise. The cost is a lot of details
> you need to understand well.
Yes!
And my maths is rusty ... I'll read the rest when my mind is clearer tomorrow!
thanks a lot, cheers,
King
>> Yes, this fixes the bias I saw, though I don't get it: why do the
>> superfluous calls in the "buggy" code lead to the bias?
>
> This is just an artifact of the mathematics. Consider a simple case of
> three elements A B C, with probabilities a, b, and c respectively;
> a+b+c=1.
>
> When we make the cumulative sum list, we get a, a+b, and 1.
>
> Now the first element, A, gets accepted if a PRNG is <a. This is
> probability a. All good so far.
>
> The second one has a (1-a) probability of being even considered, after
> which it has to pass a PRNG test with prob. a+b. So the probability
> that B is selected is (1-a)(a+b)=b+a(1-a-b)=b+ac>b. So B will be
> selected with probability greater than b. [ In the particular example
> you used, a=0.05, b=0.15 and b+a(1-a-b)=0.19, consistent with the
> observed distribution from your tests. ]
>
> In general, the middle choices will be selected more frequently than
> they should be. As you sorted the choices by probability, the middle
> choices are the mid range probabilities.
>
> Regards,
>
> Andrew.
>
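Andrew's algebra can also be checked empirically. The sketch below
(not from the thread; all names made up) simulates the buggy
redraw-per-comparison predicate against the single-draw fix, using
a = 0.05, b = 0.15, c = 0.8:

```scala
import scala.util.Random

// Cumulative weights for A (a = 0.05), B (b = 0.15), C (c = 0.8)
val cumulative = Vector(0.05, 0.2, 1.0)

// Buggy: a fresh nextDouble per comparison, as in "_.weight >= r.nextDouble"
def buggyPick(r: Random): Int = cumulative.indexWhere(_ >= r.nextDouble())

// Fixed: one draw per selection
def fixedPick(r: Random): Int = {
  val d = r.nextDouble()
  cumulative.indexWhere(_ >= d)
}

// Observed frequency of B (index 1) over n selections
def freqOfB(pick: Random => Int, n: Int, seed: Long): Double = {
  val r = new Random(seed)
  Seq.fill(n)(pick(r)).count(_ == 1).toDouble / n
}

println(f"buggy P(B) ~ ${freqOfB(buggyPick, 200000, 0)}%.3f") // near (1-a)(a+b) = 0.19
println(f"fixed P(B) ~ ${freqOfB(fixedPick, 200000, 0)}%.3f") // near b = 0.15
```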
Thanks Andrew, I get it now!
> In general, the middle choices will be selected more frequently than
> they should be. As you sorted the choices by probability, the middle
> choices are the mid range probabilities.
Yes, after understanding where the bug is, this now makes sense.
thanks again, cheers,
King
Maybe we can go through the various approaches taken at the meeting
tomorrow? And anything we learned from doing it?
There are things I'd change about both the implementation and the test
that I'll discuss tomorrow. For now I'm out of time and tired of being
inside - Sunday arvo gardening awaits...
-Ben
package test

import scala.util.Random
import scala.annotation.tailrec

//First, some generic infrastructure
object RichTraversable {
  implicit def enrich[A](t: Iterable[A]) = new RichIterable(t)

  def unfoldLeftN[A, B](seed: B, n: Int)(f: B => (B, A)) = {
    var b = seed
    var result = Vector[A]()
    for (i <- 0 until n) { // 0 until n: exactly n elements (0 to n was an off-by-one)
      val pair = f(b)
      b = pair._1
      result = result :+ pair._2
    }
    result
  }
}

class RichIterable[A](coll: Iterable[A]) {
  @tailrec
  final def foldLeftSelect[B](z: B)(op: (B, A) => B, p: B => Boolean):
      (Option[A], Iterable[A]) = {
    if (coll.isEmpty) return (None, Iterable.empty)
    val a = coll.head
    val t = coll.tail
    val newZ = op(z, a)
    if (p(newZ)) return (Some(a), t)
    new RichIterable(t).foldLeftSelect(newZ)(op, p)
  }

  def cyclic: Iterable[A] = new Iterable[A] {
    class CyclicIterator(traversable: Traversable[A]) extends Iterator[A] {
      var currentIterator = traversable.toIterator
      def next() = {
        if (!currentIterator.hasNext) currentIterator = traversable.toIterator
        currentIterator.next()
      }
      def hasNext = !traversable.isEmpty
    }
    def iterator = new CyclicIterator(coll)
  }.view
}

import RichTraversable._

object RouletteSampling extends App {
  //the correctness test
  val n = 100
  val colors = Vector("Red", "Green", "Blue", "Yellow")
  val weights = Vector(0.05, 0.15, 0.6, 0.2)
  val items = (colors, weights).zipped.map(new WeightedItem(_, _))
  val selections = weightedSelection(items, n, new Random())

  val expectedFreq = (colors, weights.map(_ * n)).zipped.toMap
  val observedFreq = selections.groupBy((s) => s).mapValues(_.size.toDouble)
  // getOrElse: a colour may not appear at all in a small sample
  val pairwise = colors.map((color) =>
    (expectedFreq(color), observedFreq.getOrElse(color, 0.0)))
  val chiSquare = pairwise.map({ case (e, o) => ((o - e) * (o - e)) / e }).sum
  val isLikely = chiSquare < 6.25
  val isPlausible = chiSquare < 12.84
  println("Likely: " + isLikely + " Plausible: " + isPlausible +
    " Chi Square: " + chiSquare)

  //the implementation itself
  case class WeightedItem[T](item: T, weight: Double)

  def weightedSelection[T](weightedItems: IndexedSeq[WeightedItem[T]],
      n: Int, r: Random): Seq[T] = {
    val weights = weightedItems.map(_.weight)
    val maxIncr = weights.max * 2.0
    val randomIndex = r.nextInt(weightedItems.length)
    val (head, tail) = weightedItems.splitAt(randomIndex)
    val initialWheel = (tail ++ head).cyclic
    def randomIncr = r.nextDouble() * maxIncr
    unfoldLeftN(initialWheel, n)((wheel) => {
      val incr = randomIncr
      val (selection, nextWheel) = wheel.foldLeftSelect(0.0)(
        (cumWeight, a) => a.weight + cumWeight, (weight) => weight > incr)
      (nextWheel, selection.get.item)
    })
  }
}
Yes, Roulette Selection is a (supposedly) fast-but-approximate method
more suitable for a larger number of selections.
Yes, I'd noticed there was some kind of horrible flaw in there causing
very slow performance for high N, but hadn't the patience to locate
it. Thanks for pinpointing it.
Two other points I want to raise tonight re: the weighted selection
1. I'm coming to the opinion that the original API signature I
specified was non-optimal, i.e.

def weightedSelection[T](
    items: Seq[WeightedItem[T]], numSelections: Int, r: Random): Seq[T]
Rather, it should be represented as an Iterator or Stream of
selections, which you can pull on as many times as you need.
2. According to Wikipedia, the (Pearson) Chi Square test I used to
measure goodness-of-fit assumes that your selection categories have
near-equal weighting, which isn't necessarily true.
The basic problem is we know the population distribution (ie the
weightings), and we want to test how likely the observed sample (the
selections) is to have come from that population.
Can anyone suggest a more appropriate method, when the selection
categories might wildly differ in frequency?
-Ben
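As an illustration of Ben's point 1, an Iterator-shaped API might look
like the sketch below (assumed names, not an agreed design; it wraps a
single-selection step in Iterator.continually):

```scala
import scala.util.Random

case class WeightedItem[T](item: T, weight: Double)

// An endless Iterator of selections that callers pull from as needed;
// the cumulative-weight table is built once, up front.
def weightedSelections[T](items: IndexedSeq[WeightedItem[T]],
                          r: Random): Iterator[T] = {
  val cumulative = items.scanLeft(0.0)(_ + _.weight).tail
  val total = cumulative.last
  Iterator.continually {
    val d = r.nextDouble() * total
    items(cumulative.indexWhere(_ >= d)).item
  }
}

// The original signature is then just a take(...) over the iterator:
def weightedSelection[T](items: IndexedSeq[WeightedItem[T]],
                         numSelections: Int, r: Random): Seq[T] =
  weightedSelections(items, r).take(numSelections).toSeq
```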