Proposed solution for a squared distance matrix + merge request

Daniel Korzekwa

Sep 9, 2015, 10:21:52 AM
to Scala Breeze
Could we merge a function for computing a squared distance matrix into Breeze?

impl:

test:


My implementation is based on sq_dist() from the gpml MATLAB library, but it uses broadcasted rows instead of bsxfun. Although my current implementation is much faster than iterating over all cells of the distance matrix, it is still about 3 times slower than sq_dist() in Octave.
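For reference, sq_dist rests on the expansion ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j, which turns the whole distance matrix into one matrix product plus two broadcasted additions. The sketch below spells that identity out in plain Scala over column-major arrays rather than Breeze, with points stored as columns (matching sum(a.*a,1) in the Octave code later in this thread). The object and method names are hypothetical, and the cell-by-cell version is included only so the two can be checked against each other.

```scala
// Hypothetical helper, not Breeze API. A d x m matrix is a flat Array[Double]
// in column-major order, so column c occupies indices c*d until (c+1)*d.
object SqDistSketch {

  // Direct version: iterate over every cell and sum squared differences.
  def naive(a: Array[Double], b: Array[Double], d: Int, m: Int, n: Int): Array[Double] = {
    val out = new Array[Double](m * n) // m x n result, column-major
    for (j <- 0 until n; i <- 0 until m) {
      var s = 0.0
      for (k <- 0 until d) {
        val diff = a(i * d + k) - b(j * d + k)
        s += diff * diff
      }
      out(j * m + i) = s
    }
    out
  }

  // Expansion used by sq_dist: ||a_i||^2 + ||b_j||^2 - 2 * a_i . b_j.
  def expanded(a: Array[Double], b: Array[Double], d: Int, m: Int, n: Int): Array[Double] = {
    // Squared norm of each column (what sum(a.*a, 1) computes in Octave).
    def colNorms(x: Array[Double], cols: Int): Array[Double] =
      Array.tabulate(cols) { c =>
        var s = 0.0
        for (k <- 0 until d) s += x(c * d + k) * x(c * d + k)
        s
      }
    val na = colNorms(a, m)
    val nb = colNorms(b, n)
    val out = new Array[Double](m * n)
    for (j <- 0 until n; i <- 0 until m) {
      var dot = 0.0 // entry (i, j) of a' * b
      for (k <- 0 until d) dot += a(i * d + k) * b(j * d + k)
      // Floating-point cancellation can leave tiny negatives; clamp at zero.
      out(j * m + i) = math.max(na(i) + nb(j) - 2.0 * dot, 0.0)
    }
    out
  }
}
```

In Octave (and in a Breeze implementation) the inner dot products collapse into a single a' * b matrix multiply handed to BLAS, which is where the speed comes from; the loops here only make the arithmetic explicit.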

What can we do to match the performance of this function in Octave, and why is it 3 times slower?

PS.
I'm happy to turn it into a ufunc and provide a pull request; please just create a skeleton file for this function in Breeze.
My ufunc() implementation works with 0.12-SNAPSHOT only; broadcasting does not work with Breeze 0.11.2 (some compilation errors, missing implicits).
Somehow it is about 30% slower when using natives.

thanks.

David Hall

Sep 10, 2015, 3:13:25 AM
to scala-...@googlegroups.com
Thanks!

We'd have to do some careful profiling. Could you run it through hprof (java -Xrunhprof:cpu=samples,depth=12) and share the results? It would be even more amazing if you could time the individual pieces of the MATLAB implementation (i.e. how much time is spent on each line), just to know where the ceiling is.

I have started profiling when it's actually better to use natives/netlib-java, and I'm increasingly finding it's not worth it except for rather large inputs. A careful Java implementation gets most of the performance, especially given the overhead of calling into native code. As an example: I sped up short dense dot products (~8 elements) by a factor of 100(!) through careful profiling/JVM engineering. It took vectors of length 200 to see benefits from using netlib-java.
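The thread doesn't show the dot-product change itself, so the sketch below is only a generic illustration (hypothetical name, not Breeze's code) of why a plain JVM loop wins for short vectors: it works on primitive doubles with no boxing, no closure per element, and no JNI call, so its fixed cost per invocation is close to zero, whereas a native BLAS call pays a per-call overhead that only amortizes over longer vectors.

```scala
object DotSketch {
  // Dot product of two equal-length arrays of primitive doubles.
  // A plain while loop: no boxing, no closure per element, no native call.
  def dot(x: Array[Double], y: Array[Double]): Double = {
    require(x.length == y.length, "length mismatch")
    var sum = 0.0
    var i = 0
    while (i < x.length) {
      sum += x(i) * y(i)
      i += 1
    }
    sum
  }
}
```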



--
You received this message because you are subscribed to the Google Groups "Scala Breeze" group.
To unsubscribe from this group and stop receiving emails from it, send an email to scala-breeze...@googlegroups.com.
To post to this group, send email to scala-...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/scala-breeze/5d7a995e-fc2d-44e0-bb11-afbdfd79cc62%40googlegroups.com.
For more options, visit https://groups.google.com/d/optout.

David Hall

Sep 10, 2015, 3:17:11 AM
to scala-...@googlegroups.com
As for the ufunc:

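import breeze.generic.UFunc
import breeze.linalg.DenseMatrix

object sqDist extends UFunc {
  // Squared distances between the columns of x1 (D x M) and x2 (D x N): an M x N matrix
  implicit object implDMDM extends Impl2[DenseMatrix[Double], DenseMatrix[Double], DenseMatrix[Double]] {
    def apply(x1: DenseMatrix[Double], x2: DenseMatrix[Double]): DenseMatrix[Double] = ??? // implementation goes here
  }
}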

Daniel Korzekwa

Sep 10, 2015, 7:18:07 AM
to Scala Breeze
Octave code + profiling:

 x = [1:4000];
 a = x;
 b = x;
 profile on
 bsxfun(@plus,sum(a.*a,1)',bsxfun(@minus,sum(b.*b,1),2*a'*b));
 profile off
 profshow(profile('info'))
   #            Function Attr     Time (s)        Calls
-------------------------------------------------------
   5              bsxfun             0.111            2
   4            binary *             0.049            2
   1           binary .*             0.001            2
   2                 sum             0.000            2
   3           postfix '             0.000            2
   6             profile             0.000            1
   7              nargin             0.000            1
   8           binary !=             0.000            1
   9               false             0.000            1
  10 __profiler_enable__             0.000            1

second run:
profshow(profile('info'))
   #            Function Attr     Time (s)        Calls
-------------------------------------------------------
   5              bsxfun             0.123            2
   4            binary *             0.095            2
   1           binary .*             0.000            2
   2                 sum             0.000            2
   3           postfix '             0.000            2
   6             profile             0.000            1
   7              nargin             0.000            1
   8           binary !=             0.000            1
   9               false             0.000            1
  10 __profiler_enable__             0.000            1

Daniel Korzekwa

Sep 10, 2015, 7:25:55 AM
to Scala Breeze
Scala code + profiling:

@Test def test_david = {
  val x = DenseVector.rangeD(0, 4000, 1).toDenseMatrix

  val now = System.currentTimeMillis()
  sqDist(x, x)
  println(System.currentTimeMillis() - now)
}
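A single System.currentTimeMillis measurement around one call also times class loading and JIT compilation (ClassLoader.defineClass1 and the ZipFile entries in the profile below belong to that start-up work). A minimal harness using only the standard library (the names Bench and timeMs are hypothetical) warms the code up first and then reports the median of several System.nanoTime measurements:

```scala
object Bench {
  // Writing each result to a volatile field is a side effect, intended to
  // keep the JIT from optimizing the measured work away entirely.
  @volatile private var sink: Any = null

  // Runs `body` `warmup` times untimed (class loading and JIT compilation
  // happen here), then times `runs` calls with System.nanoTime and returns
  // the median duration in ms (upper middle value for an even count).
  def timeMs[A](warmup: Int, runs: Int)(body: => A): Double = {
    require(runs > 0, "need at least one timed run")
    var w = 0
    while (w < warmup) { sink = body; w += 1 }
    val samples = new Array[Double](runs)
    var r = 0
    while (r < runs) {
      val t0 = System.nanoTime()
      sink = body
      samples(r) = (System.nanoTime() - t0) / 1e6
      r += 1
    }
    java.util.Arrays.sort(samples)
    samples(runs / 2)
  }
}
```

For example, Bench.timeMs(warmup = 5, runs = 10)(sqDist(x, x)) would time the sqDist call above. For anything beyond a quick check, JMH (the OpenJDK microbenchmark harness) handles warm-up, forking, and dead-code elimination more rigorously.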

----------------
java -Xrunhprof:cpu=samples,depth=12

CPU SAMPLES BEGIN (total = 341) Thu Sep 10 12:23:15 2015
rank   self  accum   count trace method
   1 48.68% 48.68%     166 300156 java.net.SocketInputStream.socketRead0
   2  7.33% 56.01%      25 300096 java.lang.ClassLoader.defineClass1
   3  5.57% 61.58%      19 300607 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
   4  3.23% 64.81%      11 300591 org.netlib.blas.Dgemm.dgemm
   5  3.23% 68.04%      11 300605 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
   6  2.05% 70.09%       7 300470 java.util.zip.Inflater.inflateBytes
   7  1.76% 71.85%       6 300477 java.util.zip.ZipFile.read
   8  1.76% 73.61%       6 300611 breeze.linalg.NumericOps$class.$colon$eq
   9  1.17% 74.78%       4 300593 breeze.linalg.operators.DenseMatrixOps$$anon$135.apply
  10  0.88% 75.66%       3 300614 java.lang.Class.getComponentType
  11  0.88% 76.54%       3 300475 java.util.zip.ZipFile.getEntry
  12  0.88% 77.42%       3 300610 java.lang.Class.getComponentType
  13  0.88% 78.30%       3 300072 java.io.WinNTFileSystem.getBooleanAttributes
  14  0.59% 78.89%       2 300602 breeze.linalg.DenseVector$canDaxpy$.apply
  15  0.59% 79.47%       2 300490 java.util.zip.ZipFile.getEntry
  16  0.29% 79.77%       1 300603 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
  17  0.29% 80.06%       1 300601 scala.reflect.ManifestFactory$$anon$12.newArray
  18  0.29% 80.35%       1 300599 breeze.linalg.DenseVector.update
  19  0.29% 80.65%       1 300598 breeze.linalg.DenseMatrix$$anon$22.apply
  20  0.29% 80.94%       1 300606 scala.runtime.ScalaRunTime$.array_length
  21  0.29% 81.23%       1 300592 scala.reflect.ManifestFactory$$anon$12.newArray
  22  0.29% 81.52%       1 300608 breeze.linalg.DenseMatrix$$anon$37$$anonfun$apply$6.apply
  23  0.29% 81.82%       1 300590 org.netlib.blas.Dgemm.dgemm
  24  0.29% 82.11%       1 300589 java.lang.Class.forName0
  25  0.29% 82.40%       1 300535 scala.reflect.ManifestFactory$$anon$12.newArray
  26  0.29% 82.70%       1 300534 dk.gp.cov.CovSEisoTest.test_david
  27  0.29% 82.99%       1 300529 sun.util.calendar.ZoneInfo.getTransitionIndex
  28  0.29% 83.28%       1 300528 java.lang.Object.hashCode
  29  0.29% 83.58%       1 300028 java.lang.AbstractStringBuilder.append
  30  0.29% 83.87%       1 300527 java.util.HashMap.hash
  31  0.29% 84.16%       1 300609 breeze.linalg.DenseVector$$anon$2.apply
  32  0.29% 84.46%       1 300525 breeze.linalg.operators.MatrixOps$class.$init$
  33  0.29% 84.75%       1 300524 breeze.linalg.operators.MatrixOps$class.$init$
  34  0.29% 85.04%       1 300523 breeze.linalg.operators.MatrixOps$class.$init$
  35  0.29% 85.34%       1 300522 java.io.WinNTFileSystem.getBooleanAttributes
  36  0.29% 85.63%       1 300521 breeze.linalg.operators.MatrixOps$$anon$54.<init>
  37  0.29% 85.92%       1 300520 breeze.linalg.DenseVector$mcD$sp.asDenseMatrix$mcD$sp
  38  0.29% 86.22%       1 300519 java.lang.Throwable.fillInStackTrace
  39  0.29% 86.51%       1 300516 java.lang.CharacterDataLatin1.toUpperCase
  40  0.29% 86.80%       1 300616 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
  41  0.29% 87.10%       1 300511 breeze.linalg.DenseVector$.canDim
  42  0.29% 87.39%       1 300510 java.security.AccessController.doPrivileged
  43  0.29% 87.68%       1 300509 java.io.WinNTFileSystem.getBooleanAttributes
  44  0.29% 87.98%       1 300501 scala.collection.immutable.Seq$.newBuilder
  45  0.29% 88.27%       1 300604 scala.collection.immutable.Range.foreach
  46  0.29% 88.56%       1 300489 breeze.linalg.operators.DenseVectorOps$class.$init$
  47  0.29% 88.86%       1 300488 sun.util.calendar.AbstractCalendar.getTime
  48  0.29% 89.15%       1 300487 java.io.WinNTFileSystem.getBooleanAttributes
  49  0.29% 89.44%       1 300486 java.lang.ClassLoader.loadClass
  50  0.29% 89.74%       1 300485 scala.collection.mutable.HashTable$class.scala$collection$mutable$HashTable$$addEntry0
  51  0.29% 90.03%       1 300484 breeze.linalg.operators.DenseVectorOps$$anon$13.<init>
  52  0.29% 90.32%       1 300483 breeze.linalg.operators.DenseVectorOps$class.$init$
  53  0.29% 90.62%       1 300612 scala.reflect.ManifestFactory$$anon$12.newArray
  54  0.29% 90.91%       1 300613 scala.collection.immutable.Range.foreach
  55  0.29% 91.20%       1 300474 java.io.WinNTFileSystem.getBooleanAttributes
  56  0.29% 91.50%       1 300473 java.lang.String.substring
  57  0.29% 91.79%       1 300472 breeze.linalg.VectorOps$$anon$56.<init>
  58  0.29% 92.08%       1 300471 breeze.linalg.VectorOps$class.$init$
  59  0.29% 92.38%       1 300526 breeze.linalg.operators.MatrixOps$class.$init$
  60  0.29% 92.67%       1 300440 scala.package$.<init>
  61  0.29% 92.96%       1 300060 java.lang.System.arraycopy
  62  0.29% 93.26%       1 300402 scala.Option$.<clinit>
  63  0.29% 93.55%       1 300360 scala.collection.mutable.AbstractMap.<init>
  64  0.29% 93.84%       1 300330 breeze.generic.MMRegistry2$class.$init$
  65  0.29% 94.13%       1 300326 breeze.linalg.DenseVector$.<init>
  66  0.29% 94.43%       1 300324 java.io.WinNTFileSystem.getBooleanAttributes
  67  0.29% 94.72%       1 300322 java.net.URLClassLoader.findClass
  68  0.29% 95.01%       1 300321 java.lang.Throwable.fillInStackTrace
  69  0.29% 95.31%       1 300318 org.junit.runners.BlockJUnit4ClassRunner.testName
  70  0.29% 95.60%       1 300311 java.util.Arrays.sort
  71  0.29% 95.89%       1 300617 breeze.linalg.NumericOps$class.$colon$eq
  72  0.29% 96.19%       1 300308 java.lang.System.arraycopy
  73  0.29% 96.48%       1 300615 breeze.linalg.DenseVector.copy
  74  0.29% 96.77%       1 300306 java.io.WinNTFileSystem.getBooleanAttributes
  75  0.29% 97.07%       1 300281 sun.reflect.annotation.AnnotationParser$1.run
  76  0.29% 97.36%       1 300258 java.net.URLStreamHandler.setURL
  77  0.29% 97.65%       1 300247 java.io.WinNTFileSystem.getBooleanAttributes
  78  0.29% 97.95%       1 300192 sun.util.locale.provider.JRELocaleProviderAdapter.getLanguageTagSet
  79  0.29% 98.24%       1 300512 java.lang.ClassLoader.findBootstrapClass
  80  0.29% 98.53%       1 300155 java.lang.System.nanoTime
  81  0.29% 98.83%       1 300146 java.net.DualStackPlainSocketImpl.connect0
  82  0.29% 99.12%       1 300110 java.lang.ClassLoader$NativeLibrary.load
  83  0.29% 99.41%       1 300101 java.io.WinNTFileSystem.getBooleanAttributes
  84  0.29% 99.71%       1 300083 java.util.zip.Inflater.inflateBytes
  85  0.29% 100.00%       1 300618 breeze.linalg.DenseMatrix$$anon$35$$anonfun$apply$4.apply
CPU SAMPLES END
-------------

Daniel Korzekwa

Sep 10, 2015, 7:31:02 AM
to Scala Breeze
Profile for running the Scala code in a 100x loop:
@Test def test_david = {
  val x = DenseVector.rangeD(0, 4000, 1).toDenseMatrix

  val now = System.currentTimeMillis()
  for (i <- 1 to 100) sqDist(x, x)
  println(System.currentTimeMillis() - now)
}

CPU SAMPLES BEGIN (total = 19555) Thu Sep 10 12:30:07 2015
rank   self  accum   count trace method
   1 49.96% 49.96%    9769 300162 java.net.SocketInputStream.socketRead0
   2 18.84% 68.80%    3685 300628 scala.runtime.BoxesRunTime.boxToDouble
   3 10.28% 79.08%    2010 300632 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
   4  9.75% 88.83%    1907 300631 java.lang.Number.<init>
   5  1.59% 90.42%     310 300602 breeze.linalg.operators.DenseMatrixOps$$anon$135.apply
   6  1.22% 91.63%     238 300636 breeze.linalg.DenseVector$canDaxpy$.apply
   7  1.03% 92.67%     202 300639 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
   8  0.93% 93.59%     181 300629 scala.runtime.BoxesRunTime.unboxToDouble
   9  0.64% 94.24%     126 300633 breeze.linalg.BroadcastedRows$$anon$3$$anonfun$apply$3.apply
  10  0.64% 94.88%     126 300641 scala.collection.immutable.Range.foreach
  11  0.52% 95.40%     102 300621 scala.reflect.ManifestFactory$$anon$12.newArray
  12  0.52% 95.92%     101 300608 scala.reflect.ManifestFactory$$anon$12.newArray
  13  0.51% 96.43%     100 300541 scala.reflect.ManifestFactory$$anon$12.newArray
  14  0.50% 96.93%      97 300600 scala.reflect.ManifestFactory$$anon$12.newArray
  15  0.49% 97.42%      96 300642 com.github.fommil.netlib.F2jBLAS.dgemm
  16  0.43% 97.85%      84 300637 breeze.linalg.BroadcastedColumns$$anon$3$$anonfun$apply$3.apply
  17  0.27% 98.11%      52 300643 scala.runtime.AbstractFunction0.<init>
  18  0.25% 98.36%      49 300644 scala.runtime.AbstractFunction0.<init>
  19  0.20% 98.57%      40 300599 org.netlib.blas.Dgemm.dgemm
  20  0.10% 98.67%      19 300101 java.lang.ClassLoader.defineClass1
  21  0.10% 98.76%      19 300649 breeze.linalg.DenseVector.copy
  22  0.08% 98.84%      16 300618 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
  23  0.08% 98.92%      15 300617 scala.collection.immutable.Range.foreach
  24  0.07% 98.99%      14 300640 breeze.linalg.DenseMatrix$$anon$33.apply
  25  0.06% 99.05%      11 300614 breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply
  26  0.06% 99.11%      11 300263 java.util.zip.Inflater.inflateBytes
  27  0.05% 99.16%      10 300634 scala.collection.immutable.Range.foreach
  28  0.05% 99.21%      10 300630 scala.collection.immutable.Range.foreach
  29  0.05% 99.26%      10 300620 breeze.linalg.DenseMatrix$$anon$37$$anonfun$apply$6.apply
  30  0.05% 99.30%       9 300650 java.lang.Double.valueOf
  31  0.04% 99.34%       7 300655 breeze.linalg.DenseVector$canDaxpy$.apply
  32  0.03% 99.37%       6 300647 scala.collection.immutable.Range.foreach
  33  0.03% 99.40%       6 300638 scala.collection.immutable.Range.foreach
  34  0.02% 99.42%       4 300656 java.lang.Double.valueOf
  35  0.02% 99.44%       3 300619 java.lang.Class.getComponentType
  36  0.02% 99.45%       3 300654 breeze.linalg.DenseVector.copy
  37  0.02% 99.47%       3 300624 java.lang.Class.getComponentType
  38  0.02% 99.48%       3 300323 java.util.zip.ZipFile.getEntry
  39  0.01% 99.49%       2 300525 java.util.zip.ZipFile.read
  40  0.01% 99.50%       2 300651 breeze.linalg.operators.DenseVector_GenericOps$class.implOpSet_DV_DV_InPlace
  41  0.01% 99.51%       2 300635 scala.reflect.ManifestFactory$$anon$12.newArray
  42  0.01% 99.52%       2 300073 java.util.zip.ZipFile.open
CPU SAMPLES END

David Hall

Sep 10, 2015, 1:39:18 PM
to scala-...@googlegroups.com
Could you attach the entire .txt file? It's useful for figuring out what, e.g., breeze.linalg.operators.DenseVector_GenericOps$$anon$311.apply means.

David Hall

Sep 10, 2015, 1:39:36 PM
to scala-...@googlegroups.com
(There's a lot of boxing, which is good news, since I can usually eliminate that.)

David Hall

Sep 10, 2015, 2:33:42 PM
to scala-...@googlegroups.com
Figured out the boxing. I can probably eliminate it tonight.
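The fix itself isn't shown in the thread. As a generic illustration of where BoxesRunTime.boxToDouble and java.lang.Number.<init> samples usually come from (hypothetical names, not Breeze code): a Double stored in an unspecialized generic field is erased to Object, so every update allocates a java.lang.Double, while the same loop on a primitive accumulator allocates nothing. In Scala 2, annotating the type parameter with @specialized(Double) is one way to remove such boxing; the $mcD$sp classes in the profile (e.g. DenseVector$mcD$sp) are Breeze's specialized variants.

```scala
object BoxingSketch {
  // Unspecialized generic class: the field `value` is an Object at runtime.
  final class Cell[T](var value: T)

  def sumBoxed(xs: Array[Double]): Double = {
    val acc = new Cell[Double](0.0)
    var i = 0
    while (i < xs.length) {
      acc.value = acc.value + xs(i) // unbox, add, box a new java.lang.Double
      i += 1
    }
    acc.value
  }

  // Primitive version: the accumulator stays a JVM double.
  def sumPrimitive(xs: Array[Double]): Double = {
    var acc = 0.0
    var i = 0
    while (i < xs.length) {
      acc += xs(i)
      i += 1
    }
    acc
  }
}
```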

Daniel Korzekwa

Sep 11, 2015, 4:45:59 AM
to Scala Breeze
Attaching the complete hprof file for:

@Test def test_david = {
  val x = DenseVector.rangeD(0, 4000, 1).toDenseMatrix

  val now = System.currentTimeMillis()
  for (i <- 1 to 100) sqDist(x, x)
  println(System.currentTimeMillis() - now)
}

java.hprof.txt

David Hall

Sep 11, 2015, 12:20:23 PM
to scala-...@googlegroups.com
Try the newest snapshot. It should be about 2-3x faster.

David Hall

Sep 11, 2015, 12:21:09 PM
to scala-...@googlegroups.com
Also, the test example (a 4000x1 matrix) is maximally bad for the way I've implemented broadcasting, and not terribly realistic. Maybe try a 200x20 and/or a 20x200?

-- David

Daniel Korzekwa

Sep 11, 2015, 1:14:19 PM
to Scala Breeze
thanks!

I guess you mean [1 row x 4000 cols] instead of [4000x1]. Realistic data is [5x200]: 5 predictor features and 200 inducing points, which gives a [200x200] squared distance matrix.
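The output size follows from the input shapes alone: with points as columns, a d x m input compared against itself gives an m x m matrix, whatever d is. So the realistic [5x200] case produces 40,000 entries, while the [1x4000] benchmark produces 16 million (about 128 MB of doubles per result, before counting any intermediate M x N matrices). A trivial helper (hypothetical name) makes the arithmetic explicit:

```scala
object ResultSize {
  // With points stored as columns, inputs of shape d x m and d x n
  // give an m x n squared distance matrix; d does not affect its size.
  def entries(m: Int, n: Int): Long = m.toLong * n

  // Bytes for one m x n matrix of 8-byte doubles (array header ignored).
  def bytes(m: Int, n: Int): Long = entries(m, n) * java.lang.Double.BYTES
}
```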

OK, here are results for two scenarios, test_david and test_david2:

@Test def test_david = {
  val x = DenseVector.rangeD(0, 4000, 1).toDenseMatrix

  val now = System.currentTimeMillis()
  for (i <- 1 to 10) sqDist(x, x)
  println("[1x4000]" + (System.currentTimeMillis() - now))
}

@Test def test_david2 = {
  val x = DenseMatrix.rand(5, 200)

  val now = System.currentTimeMillis()
  for (i <- 1 to 1000) sqDist(x, x)
  println("[5x200]" + (System.currentTimeMillis() - now))
}

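Each test was run three times against the old and new Breeze snapshots (times in milliseconds; test_david makes 10 calls on a [1x4000] matrix, test_david2 makes 1000 calls on a [5x200] matrix):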

david_old   david_new   david2_old   david2_new
-----------------------------------------------
12224       8164        2765         1648
12203       8500        2802         1995
12417       8562        2874         1796
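Averaging the three runs gives the speedup of the new snapshot directly. The snippet below does nothing more than that arithmetic on the numbers in the table above (the object name is hypothetical):

```scala
object Speedup {
  def mean(xs: Seq[Double]): Double = xs.sum / xs.size

  // Timings in ms copied from the table above (three runs per column).
  val davidOld  = Seq(12224.0, 12203.0, 12417.0)
  val davidNew  = Seq(8164.0, 8500.0, 8562.0)
  val david2Old = Seq(2765.0, 2802.0, 2874.0)
  val david2New = Seq(1648.0, 1995.0, 1796.0)

  // Ratio of mean old time to mean new time; > 1 means the new snapshot is faster.
  val davidSpeedup: Double  = mean(davidOld) / mean(davidNew)   // ~1.46
  val david2Speedup: Double = mean(david2Old) / mean(david2New) // ~1.55
}
```

On these runs the new snapshot is roughly 1.5x faster in both scenarios, somewhat below the 2-3x estimate.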

David Hall

Sep 11, 2015, 1:25:57 PM
to scala-...@googlegroups.com
How does that compare to Octave?

David Hall

Sep 11, 2015, 1:26:34 PM
to scala-...@googlegroups.com
I'd wager 2x slower for david_old?

Daniel Korzekwa

Sep 11, 2015, 1:52:42 PM
to Scala Breeze
Octave code: tic; for i = 1:1000 bsxfun(@plus,sum(a.*a,1)',bsxfun(@minus,sum(b.*b,1),2*a'*b)); end; toc

[1x4000]: 2 seconds (4 times faster than david_new)
[5x200]: 0.367 seconds (4.5 times faster than david2_new)

David Hall

Sep 11, 2015, 1:57:54 PM
to scala-...@googlegroups.com
Sigh, lots of work to do.

David Hall

Sep 12, 2015, 1:39:41 AM
to scala-...@googlegroups.com
OK, this vectorized implementation is roughly 2x faster than the one using broadcasting. It spends roughly 30% of its time doing matrix multiplication, 11% doing addition and subtraction, and the rest I can't figure out. This could probably be improved further by rolling better loops for broadcasted add, but I think I'll stop for now.

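import breeze.linalg.DenseMatrix

// x1 is D x M and x2 is D x N (points are columns); the result is M x N.
def sqDist(x1: DenseMatrix[Double], x2: DenseMatrix[Double]): DenseMatrix[Double] = {
  val D = x1.rows
  val M = x1.cols
  val N = x2.cols
  val x1OnesT = DenseMatrix.ones[Double](M, D)
  val x2Ones = DenseMatrix.ones[Double](D, N)
  // (M x D) * (D x N) -= (D x M).t * (D x N) += (M x D) * (D x N), applied left to right in place,
  // i.e. ||x1_i||^2 - 2 * x1_i . x2_j + ||x2_j||^2 for every (i, j)
  (x1 :* x1).t * x2Ones -= (x1.t * x2 *= 2.0) += x1OnesT * (x2 :* x2)
}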

-- David
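One possible shape for "better loops for broadcasted add", sketched in plain Scala over column-major arrays rather than Breeze (all names are hypothetical, and this is not what Breeze does): compute the two squared-norm vectors once, then add them into the cross-product matrix in a single pass, instead of materializing the two ones-matrix products. The cross term x1.t * x2 is assumed to come from a matrix multiply (BLAS dgemm in Breeze) and is passed in.

```scala
object FusedSqDist {
  // Squared norm of each column of a d x cols matrix stored column-major.
  def colSqNorms(x: Array[Double], d: Int, cols: Int): Array[Double] = {
    val out = new Array[Double](cols)
    var c = 0
    while (c < cols) {
      var s = 0.0
      var k = 0
      while (k < d) { val v = x(c * d + k); s += v * v; k += 1 }
      out(c) = s
      c += 1
    }
    out
  }

  // `cross` holds x1.t * x2 (m x n, column-major). It is overwritten in place
  // with ||x1_i||^2 - 2 * cross(i, j) + ||x2_j||^2, the same three terms as
  // the ones-matrix expression, in one pass over the m x n output.
  def finishInPlace(cross: Array[Double], n1: Array[Double], n2: Array[Double]): Array[Double] = {
    val m = n1.length
    val n = n2.length
    require(cross.length == m * n, "cross must be m x n")
    var j = 0
    while (j < n) {
      val base = j * m
      val nj = n2(j)
      var i = 0
      while (i < m) {
        cross(base + i) = n1(i) + nj - 2.0 * cross(base + i)
        i += 1
      }
      j += 1
    }
    cross
  }
}
```

The design choice is simply to avoid allocating and multiplying by the two ones matrices: the broadcasted add then costs one read and one write per output entry.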