from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
# Build the ``pythonznn`` extension module: the Cython wrapper pyznn.pyx,
# compiled as C++ and linked against FFTW and Boost.
#
# pyznn.pyx declares ``cdef extern from "znn.cxx"``, so Cython textually
# #includes znn.cxx into the generated pyznn.cpp. Listing znn.cxx in
# ``sources`` as well would compile znn_forward a second time in its own
# translation unit and the link would fail with duplicate symbols, so it is
# only listed in ``depends`` (rebuild when it changes).
setup(ext_modules=cythonize(Extension(
    "pythonznn",  # the extension name
    # the Cython source; the C++ implementation is pulled in via #include
    sources=["pyznn.pyx"],
    depends=["znn.cxx"],
    # generate and compile C++ code
    language="c++",
    include_dirs=["../", "../../src", "../../zi", "/usr/people/jingpeng/libs/boost/include/"],
    # Each directory must be its own argument: the build spawns the linker
    # without a shell, so "-La -Lb" in one string would be a single bogus path.
    library_dirs=["../../", "/usr/people/jingpeng/libs/boost/lib64/"],
    libraries=["stdc++", "fftw3", "pthread", "rt", "fftw3_threads", "boost_program_options", "boost_regex", "boost_filesystem", "boost_system", "boost_timer"],
    extra_compile_args=["-g"],
    extra_link_args=["-g"],
)))
import numpy as np
cimport numpy as np
# C++ entry point implemented in znn.cxx. ``cdef extern from`` makes Cython
# emit ``#include "znn.cxx"`` into the generated C++ source.
#   input_py  : iz*iy*ix doubles of input volume data
#   output_py : oz*oy*ox doubles, written by the C++ side
cdef extern from "znn.cxx":
    void znn_forward(double* input_py,
                     unsigned int iz, unsigned int iy, unsigned int ix,
                     double* output_py,
                     unsigned int oz, unsigned int oy, unsigned int ox)
def run_forward(np.ndarray[double, ndim=3, mode="c"] input not None,
                np.ndarray[double, ndim=3, mode="c"] output not None):
    """
    Run the forward pass of ZNN on ``input``, writing into ``output``.

    Both arguments must be 3-D, C-contiguous float64 numpy arrays. Their
    three extents are forwarded to ``znn_forward`` as (z, y, x) sizes.
    ``output`` is filled in place; its shape is not validated here, so it
    presumably has to match what the network produces (confirm with the
    network configuration).
    """
    cdef int iz, iy, ix
    cdef int oz, oy, ox

    # Volume extents, outermost axis first.
    iz, iy, ix = input.shape[0], input.shape[1], input.shape[2]
    oz, oy, ox = output.shape[0], output.shape[1], output.shape[2]

    # Raw pointers to the first element of each contiguous buffer.
    cdef double* in_buf = &input[0, 0, 0]
    cdef double* out_buf = &output[0, 0, 0]

    znn_forward(in_buf, iz, iy, ix,
                out_buf, oz, oy, ox)
#ifndef ZNN_CXX_INCLUDED
#define ZNN_CXX_INCLUDED
#include "../core/network.hpp"
#include "../front_end/options.hpp"
#include <zi/zargs/zargs.hpp>
// Short alias for the ZNN namespace.
// NOTE(review): not referenced anywhere in this file, and the loop counters
// named `z` in znn_forward hide it; looks safe to remove — confirm.
namespace z = zi::znn;
using namespace zi::znn;
// Command-line flags registered through zi::zargs. znn_forward fills them
// by handing a fabricated argv to zi::parse_arguments and then reads
// ZiARG_options; `test_only` is not read in this file.
ZiARG_string(options, "", "Option file path");
ZiARG_bool(test_only, true, "Test only");
void znn_forward( double* input_py, unsigned int iz, unsigned int iy, unsigned int ix,
double* output_py, unsigned int oz, unsigned int oy, unsigned int ox)
{
// options, create fake main function parameters
char *argv[] = { "run_znn_forward", "--options=forward.config", NULL };
int argc=2;
zi::parse_arguments(argc, argv);
options_ptr op = options_ptr(new options(ZiARG_options));
op->save();
// create network
network net(op);
// initialization
double3d_ptr pinput = volume_pool.get_double3d(ix,iy,iz);
double3d& input = *pinput;
// input.data() = input_array.get_data();
int index = 0;
for (std::size_t z=0; z<iz; z++)
for (std::size_t y=0; y<iy; y++)
for (std::size_t x=0; x<ix; x++)
{
input[x][y][z] = input_py[index];
index++;
}
std::list<double3d_ptr> pinputs;
pinputs.push_back(pinput);
// prepare output
std::list<double3d_ptr> poutputs;
poutputs = net.run_forward(pinputs);
double3d& output = *(poutputs.front());
// give value to python numpy output array
index = 0;
for (std::size_t z=0; z<oz; z++)
for (std::size_t y=0; y<oy; y++)
for (std::size_t x=0; x<ox; x++)
{
output_py[index] = output[x][y][z];
index++;
}
}
#endif // ZNN_CXX_INCLUDED