#include <iostream>
#include <iomanip>
#include <slate/slate.hh>
#include <slateUtils.hh>
/// Fill every locally-owned tile of A with a deterministic test pattern.
///
/// Element (r, c) of A (global indices within A) receives the value
/// r * A.n() + c, i.e. its row-major linear index. Tiles for which
/// A.tileIsLocal(i, j) is false are skipped, but their extents still advance
/// the running row/column offsets so global indices stay correct.
///
/// After filling, A.tileUpdateAllOrigin() and A.releaseWorkspace() are called.
///
/// @tparam scalar_t  Element type of the matrix; must be assignable from int64_t.
/// @param  A         Matrix (or slice) whose local tiles are overwritten.
template <typename scalar_t>
void populate(slate::Matrix<scalar_t> &A) {
    const int64_t n = A.n();
    const int64_t mt = A.mt();
    const int64_t nt = A.nt();

    int64_t col_offset = 0;  // global column index of the first column of tile column j
    for (int64_t j = 0; j < nt; ++j) {
        const int64_t nb = A.tileNb(j);

        int64_t row_offset = 0;  // global row index of the first row of tile row i
        for (int64_t i = 0; i < mt; ++i) {
            // Tile height is a property of the tile ROW, so it must be queried
            // with i (not j); otherwise short edge tiles get the wrong extent.
            const int64_t mb = A.tileMb(i);

            if (A.tileIsLocal(i, j)) {
                // Request the tile in row-major layout so that element (ii, jj)
                // is addressed as jj + ii * stride below.
                A.tileGetForWriting(i, j, slate::LayoutConvert::RowMajor);
                auto tile = A(i, j);
                auto tiledata = tile.data();
                const int64_t stride = tile.stride();

                for (int64_t ii = 0; ii < mb; ++ii) {
                    for (int64_t jj = 0; jj < nb; ++jj) {
                        tiledata[jj + ii * stride] = (row_offset + ii) * n + col_offset + jj;
                    }
                }
            }
            row_offset += mb;
        }
        col_offset += nb;
    }

    A.tileUpdateAllOrigin();
    A.releaseWorkspace();
}
int main(int argc, char** argv) {
// Initialize MPI, require MPI_THREAD_MULTIPLE support
int err=0, mpi_provided=0;
err = MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &mpi_provided );
assert( err == 0 && mpi_provided == MPI_THREAD_MULTIPLE );
int64_t m = 10;
int64_t n = m;
int64_t nb = 4;
int64_t p = 2;
int64_t q = 2;
double eps = std::numeric_limits<double>::epsilon();
double beta = std::sqrt(eps);
slate::Matrix<double> A(m + n, n, nb, p, q, MPI_COMM_WORLD );
A.insertLocalTiles();
auto H = A.slice(0, m-1, 0, n-1);
auto R = A.slice(m, m+n-1, 0, n-1);
populate(A);
double gamma = std::sqrt(m * beta / n);
std::vector<double> diag(n, gamma);
slateUtils::matrix::diagonal(R, diag); // Sets diag as the diagonal, 0 elsewhere
slate::Matrix<double> B(m+n, 1, nb, p, q, MPI_COMM_WORLD );
B.insertLocalTiles();
auto Y = B.slice(0, m-1, 0, 0);
auto YR = B.slice(m, m+n-1, 0, 0);
populate(Y);
slateUtils::matrix::zero(YR); // Sets all elements to zero
slateUtils::matrix::csv::write(A, "outA0.csv", ',', false); // Writes matrix to .csv file
slateUtils::matrix::csv::write(B, "outB0.csv", ',', false);
slate::gels<double>(A, B);
slateUtils::matrix::csv::write(A, "outA1.csv", ',', false);
slateUtils::matrix::csv::write(B, "outB1.csv", ',', false);
return 0;
}