from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
# Build configuration for the "pythonznn" extension module: the Cython
# wrapper (pyznn.pyx) and the C++ entry point (znn.cxx) are compiled as C++
# with debug symbols and linked against FFTW3 and Boost.
#
# NOTE(review): if pyznn.pyx pulls in znn.cxx textually (``cdef extern from
# "znn.cxx"``), compiling znn.cxx a second time via ``sources`` would define
# its symbols twice at link time -- confirm which of the two is intended.
setup(ext_modules=cythonize(Extension(
    "pythonznn",  # the extension name
    # the Cython source and additional C++ source files
    sources=["pyznn.pyx", "znn.cxx"],
    # generate and compile C++ code
    language="c++",
    include_dirs=[
        "../",
        "../../src",
        "../../zi",
        "/usr/people/jingpeng/libs/boost/include/",
    ],
    # Library search paths. These used to be packed into a single
    # extra_link_args string ("-L<a> -L<b>"); each list element is handed to
    # the linker as ONE argument, so the second directory was never searched.
    library_dirs=[
        "../../",
        "/usr/people/jingpeng/libs/boost/lib64/",
    ],
    libraries=[
        "stdc++",
        "fftw3",
        "pthread",
        "rt",
        "fftw3_threads",
        "boost_program_options",
        "boost_regex",
        "boost_filesystem",
        "boost_system",
        "boost_timer",
    ],
    extra_compile_args=["-g"],
    extra_link_args=["-g"],
)))
import numpy as np
cimport numpy as np
# C++ entry point declared from znn.cxx. Each buffer is passed as a raw
# pointer followed by its three extents.
cdef extern from "znn.cxx":
    void znn_forward(double* input_py,
                     unsigned int iz, unsigned int iy, unsigned int ix,
                     double* output_py,
                     unsigned int oz, unsigned int oy, unsigned int ox)
def run_forward(np.ndarray[double, ndim=3, mode="c"] input not None,
                np.ndarray[double, ndim=3, mode="c"] output not None):
    """
    Run the znn forward pass through the C++ ``znn_forward`` routine.

    Both arguments are 3-D, C-contiguous arrays of doubles and may not be
    None.  The address of each array's first element is handed to
    ``znn_forward`` together with that array's three extents, taken in
    ``shape[0], shape[1], shape[2]`` order.  Nothing is returned; the
    result is expected to be written into ``output`` by the C++ side.
    """
    cdef:
        int in_d0 = input.shape[0]
        int in_d1 = input.shape[1]
        int in_d2 = input.shape[2]
        int out_d0 = output.shape[0]
        int out_d1 = output.shape[1]
        int out_d2 = output.shape[2]

    znn_forward(&input[0, 0, 0], in_d0, in_d1, in_d2,
                &output[0, 0, 0], out_d0, out_d1, out_d2)
#ifndef ZNN_CXX_INCLUDED
#define ZNN_CXX_INCLUDED
#include "../core/network.hpp"
#include "../front_end/options.hpp"
#include <zi/zargs/zargs.hpp>
// Short alias for the zi::znn namespace.
namespace z  = zi::znn;
// Bring the zi::znn names used below into scope unqualified.
using namespace zi::znn;
// Declares the string argument `options` (default "", help text "Option file
// path"); its value is read below as ZiARG_options.
// Presumably filled in by zi::parse_arguments -- confirm against zargs.hpp.
ZiARG_string(options, "", "Option file path");
// Declares the boolean argument `test_only` (default true).
// NOTE(review): test_only is not referenced in the visible code -- confirm it
// is still needed.
ZiARG_bool(test_only, true, "Test only");
void znn_forward(   double* input_py,  unsigned int iz, unsigned int iy, unsigned int ix,
                    double* output_py, unsigned int oz, unsigned int oy, unsigned int ox)
{
    // options, create fake main function parameters
    char *argv[] = { "run_znn_forward", "--options=forward.config", NULL };
    int argc=2;
    zi::parse_arguments(argc, argv);
    options_ptr op = options_ptr(new options(ZiARG_options));
    op->save();
    // create network
    network net(op);
    // initialization
    double3d_ptr pinput = volume_pool.get_double3d(ix,iy,iz);
    double3d& input = *pinput;
    // input.data() = input_array.get_data();
    int index = 0;
    for (std::size_t z=0; z<iz; z++)
        for (std::size_t y=0; y<iy; y++)
            for (std::size_t x=0; x<ix; x++)
            {
                input[x][y][z] = input_py[index];
                index++;
            }
    std::list<double3d_ptr> pinputs;
    pinputs.push_back(pinput);
    // prepare output
    std::list<double3d_ptr> poutputs;
    poutputs = net.run_forward(pinputs);
    double3d& output = *(poutputs.front());
    // give value to python numpy output array
    index = 0;
    for (std::size_t z=0; z<oz; z++)
        for (std::size_t y=0; y<oy; y++)
            for (std::size_t x=0; x<ox; x++)
            {
                output_py[index] = output[x][y][z];
                index++;
            }
}
#endif // ZNN_CXX_INCLUDED