Hi all,
I've written a conjugate gradient (CG) algorithm to iteratively solve symmetric linear systems (it is especially efficient for sparse systems). One beauty of CG is that it doesn't really need a matrix: all it needs is the product of the matrix with a supplied vector. Thus, the "matrix" (or rather "linear operator") can be implemented as a function (essentially `mmul` "closed" over the matrix). Here is an example of a matrix representing the discretization of the second derivative (for example, the diffusive term in a mass balance), with `d` being the diffusion coefficient and `dx` the mesh step.
(defn diff-matrix [^long n]
  (let [n-1 (long (dec n))]
    (fn [x b]
      (let [d (double 1e-9)
            dx (double 1e-2)
            coeff (double (/ d dx dx))]
        (m/mset! b 0 (m/mget x 0))
        (loop [i (long 1)]
          (cond (= i n-1) (m/mset! b i (m/mget x i))
                :else (do (m/mset! b i
                                   ^double
                                   (* coeff
                                      ^double
                                      (+ ^double (m/mget x (dec i))
                                         ^double (m/mget x (inc i))
                                         ^double (* -2.0 ^double (m/mget x i)))))
                          (recur ^long (inc i)))))))))
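Written out as a formula, this is what the code above computes (zero-based indices as in the code, with \Delta x standing for `dx`; the first and last rows simply copy the boundary values):

```latex
\begin{aligned}
b_0     &= x_0, \\
b_i     &= \frac{d}{\Delta x^{2}} \bigl( x_{i-1} - 2 x_i + x_{i+1} \bigr),
           \qquad i = 1, \dots, n-2, \\
b_{n-1} &= x_{n-1}.
\end{aligned}
```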
The function is closed over `n`, the size of the problem; the main part is inside the `fn` form. It computes the product `A*x` and writes the result into `b`. A couple of points:
- I am using mutation to avoid allocation. Allocation is not a problem on its own, but with too much of it the GC has to run too often, which may slow down the computation (especially if the linear solver is part of a non-linear solver, which is part of an ODE solver...)
- I am using direct `loop`-`recur`: I've tried `map-indexed!`, but it was quite slow.
- I am using Vectorz as the implementation; I set `*warn-on-reflection*` to `true`.
The problem is that the solver is rather slow: on a 1000-element system it takes ~240 ms to solve (for comparison, my Common Lisp implementation takes ~50 ms, with vector operations not always optimised for double-float arithmetic). Profiling with `timbre` showed that the slowest step is the "matrix multiplication", i.e. applying the function returned by `diff-matrix` to `x` and `b` (it takes about 90% of the total time).
The question is why it is so slow. In my attempt to optimize `diff-matrix` I've put type hints everywhere I could, but now I'm stuck. Is there a better way to do the loop? Is it the arithmetic that is slow, or `mget`/`mset!`? I could avoid the loop if, for example, I operated on vector slices (can I, considering the mutation of `b`?) and set the first and last elements separately.
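To separate the cost of the arithmetic from the cost of `mget`/`mset!`, one thing I could time against is the same stencil on plain Java `double` arrays. This is only a rough sketch: the name `diff-matrix-arr` is made up for illustration, it bypasses core.matrix/Vectorz entirely, and it assumes `x` and `b` are `double` arrays of length `n` with `n` at least 2.

```clojure
;; Sketch only: same stencil as diff-matrix, but on primitive double arrays.
;; `diff-matrix-arr` is a hypothetical name, not part of any library.
(defn diff-matrix-arr
  "Returns (fn [x b]) that writes A*x into the double array b and returns b."
  [^long n]
  (let [n-1   (dec n)
        d     1e-9
        dx    1e-2
        coeff (/ d dx dx)]
    (fn [^doubles x ^doubles b]
      ;; first and last rows just copy the boundary values
      (aset b 0 (aget x 0))
      (aset b n-1 (aget x n-1))
      ;; interior rows: coeff * (x[i-1] - 2*x[i] + x[i+1])
      (loop [i 1]
        (when (< i n-1)
          (aset b i (* coeff (+ (aget x (dec i))
                                (aget x (inc i))
                                (* -2.0 (aget x i)))))
          (recur (inc i))))
      b)))

;; Usage: build the operator once, then apply it repeatedly.
(def op (diff-matrix-arr 1000))
(def x  (double-array (range 1000)))
(def b  (double-array 1000))
(op x b)
```

Wrapping the last call in something like `(time (dotimes [_ 10000] (op x b)))` and doing the same for the Vectorz version should show whether the time goes into the element access or into the arithmetic itself.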
PS. My aim is to bring this to <100 ms to consider the method/implementation viable.
Cheers, Alexey