I did some work on determining the best K for k-means. Here's the implementation in case anyone's interested. Feel free to integrate it into Nak.
import breeze.linalg.DenseVector
import Kmeans.{Features, _}
import nak.cluster.{Kmeans => NakKmeans}
import scala.collection.immutable.IndexedSeq
import scala.collection.mutable.ListBuffer
/**
 * Selects the number of clusters K for k-means using the f(K) criterion of
 * Pham, Dimov & Nguyen (2005), following
 * https://datasciencelab.wordpress.com/2014/01/21/selection-of-k-in-k-means-clustering-reloaded/
 *
 * For each K the k-means dispersion S_K is compared with a weighted dispersion
 * of the previous K:
 *   f(K) = 1                          if K == 1 or S_{K-1} == 0
 *   f(K) = S_K / (alpha_K * S_{K-1})  otherwise
 * where alpha_2 = 1 - 3 / (4 * N_d) and alpha_K = alpha_{K-1} + (1 - alpha_{K-1}) / 6
 * for K > 2 (N_d = number of feature dimensions). The K with the smallest f(K) wins.
 *
 * @param features the points to cluster
 */
class Kmeans(features: Features) {

  /**
   * Runs k-means with `k` clusters and evaluates f(K) for it.
   *
   * @param k                   number of clusters to fit; must be at least 1
   * @param dispersionOfKMinus1 dispersion S_{K-1} obtained for `k - 1` clusters
   *                            (ignored when `k == 1`; a value of 0 forces f(K) = 1)
   * @param alphaOfKMinus1      weight alpha_{K-1} obtained for `k - 1` clusters
   *                            (only used when `k > 2`)
   * @return `(fK, alphaK, dispersionK, centroids)` for `k` clusters, where the
   *         dispersion and centroids are those reported by nak's k-means
   * @throws IllegalArgumentException if `k < 1`
   */
  def fkAlphaDispersionCentroids(k: Int, dispersionOfKMinus1: Double = 0d, alphaOfKMinus1: Double = 1d): (Double, Double, Double, Features) = {
    require(k >= 1, s"k must be at least 1, got $k")
    // Always cluster, including for k == 1: the next K needs the real S_{K-1},
    // and a K of 1 can itself be the winner, in which case its centroid is needed.
    val (dispersion, centroids: Features) = new NakKmeans[DenseVector[Double]](features).run(k)
    val featureDimensions = features.headOption.map(_.size).getOrElse(1)
    val alpha =
      if (k == 1) 1d
      else if (k == 2) 1d - 3d / (4d * featureDimensions)
      else alphaOfKMinus1 + (1d - alphaOfKMinus1) / 6d
    val fk =
      if (k == 1 || dispersionOfKMinus1 == 0d) 1d
      else dispersion / (alpha * dispersionOfKMinus1)
    (fk, alpha, dispersion, centroids)
  }

  /**
   * Evaluates f(K) for every K from 1 to `maxK`. Each step's dispersion and
   * alpha are fed into the next step.
   *
   * @param maxK largest number of clusters to try (defaults to `Kmeans.maxK`)
   * @return one `(fK, alphaK, dispersionK, centroids)` tuple per K, in increasing
   *         K order; only the K = 1 entry when `maxK < 2`
   */
  def fks(maxK: Int = maxK): List[(Double, Double, Double, Features)] =
    (2 to maxK).scanLeft(fkAlphaDispersionCentroids(1)) {
      case ((_, alphaPrev, dispersionPrev, _), k) =>
        fkAlphaDispersionCentroids(k, dispersionPrev, alphaPrev)
    }.toList

  /**
   * Picks the K in 1..`Kmeans.maxK` with the smallest f(K). Ties go to the
   * smaller K.
   *
   * @return the dispersion and the centroids of the chosen clustering
   */
  def detK: (Double, Features) = {
    val (_, _, dispersion, centroids) = fks().minBy(_._1)
    (dispersion, centroids)
  }
}
/** Companion of `Kmeans`: shared type alias and defaults. */
object Kmeans {
  /** A collection of vectors: the input points, or the centroids found for them. */
  type Features = IndexedSeq[DenseVector[Double]]

  /** Default largest number of clusters evaluated when choosing K. */
  val maxK: Int = 10
}
Here's a simple example that runs it on three separated groups of grid points:
import breeze.linalg.DenseVector
/**
 * Demo: builds three well-separated square grids of 2-D points, runs K
 * selection on them and prints the chosen clustering. Three groups are used,
 * so K = 3 is the result one would hope `detK` picks (k-means is randomised,
 * so this is not asserted).
 */
object GeographicClustering extends App {

  /** All integer grid points (x, y) with `min <= x <= max` and `min <= y <= max`, x-major order. */
  def makeGroup(min: Int, max: Int): Vector[DenseVector[Double]] =
    (for {
      gx <- min to max
      gy <- min to max
    } yield DenseVector(gx.toDouble, gy.toDouble)).toVector

  val points: Vector[DenseVector[Double]] = makeGroup(0, 2) ++ makeGroup(6, 7) ++ makeGroup(12, 13)

  // Printed in numpy `array([...])` syntax, presumably so the data can be compared
  // against the Python implementation from the referenced blog post.
  println(points.map(_.toArray.mkString("[", ",", "]")).mkString("array([", ",", "])"))

  val kmeans = new Kmeans(points)
  val x = kmeans.detK
  println(x)
}
Regards, Eirik