diff --git a/compyle/array.py b/compyle/array.py index 172add4..496eda8 100644 --- a/compyle/array.py +++ b/compyle/array.py @@ -1,11 +1,10 @@ import numpy as np import math import mako.template as mkt -import time from pytools import memoize, memoize_method from .config import get_config -from .types import (annotate, dtype_to_ctype, ctype_to_dtype, declare, +from .types import (annotate, dtype_to_ctype, declare, dtype_to_knowntype, knowntype_to_ctype) from .template import Template from .sort import radix_sort @@ -394,7 +393,6 @@ def linspace(start, stop, num, dtype=np.float64, backend='opencl', out = out * delta+start elif backend == 'cuda': import pycuda.gpuarray as gpuarray - import pycuda.autoinit if endpoint: delta = (stop-start)/(num-1) else: @@ -445,7 +443,6 @@ def diff(a, n, backend=None): backend = a.backend if backend == 'opencl' or backend == 'cuda': - from compyle.api import Elementwise binom_coeff = np.zeros(n+1) sign_fac = 1 if (n % 2 == 0) else -1 for i in range(n+1): @@ -526,6 +523,7 @@ def trapz(y, x=None, dx=1.0, backend=None): out = dot(d, sum_ar) * 0.5 return out + @annotate def where_elwise(i, condition, x, y, ans): if condition[i]: @@ -872,7 +870,6 @@ def comparison_kernel(func, backend, ary_type, other_type): def comparison_template(func, other, arr, backend=None): if backend is None: backend = arr.backend - from compyle.parallel import Elementwise other_type = dtype_to_ctype(type(other)) ary_type = dtype_to_ctype(arr.dtype) + 'p' ans = empty(arr.length, dtype=np.int32, backend=arr.backend) @@ -1023,7 +1020,7 @@ def get_buff(self, offset=0, length=0): return cu_bufint(self._data, nbytes, int(offset)) def get(self): - if self.backend == 'cython': + if self.backend == 'cython' or self.backend == 'c': return self.dev elif self.backend == 'opencl' or self.backend == 'cuda': return self.dev.get() diff --git a/compyle/c_backend.py b/compyle/c_backend.py new file mode 100644 index 0000000..bdabff6 --- /dev/null +++ b/compyle/c_backend.py @@ -0,0 +1,322 @@ +from compyle.profile import profile +from .translator import ocl_detect_type, KnownType +from .cython_generator import CythonGenerator, get_func_definition +from .cython_generator import getsourcelines +from mako.template import Template +from .ext_module import get_md5 +from .cimport import Cmodule +from .transpiler import Transpiler +from . import array + +import pybind11 +import numpy as np + + +elwise_c_pybind = ''' + +PYBIND11_MODULE(${modname}, m) { + + m.def("${modname}", [](${pyb11_args}){ + return ${name}(${pyb11_call}); + }); +} + +''' + + +class CBackend(CythonGenerator): + def __init__(self, detect_type=ocl_detect_type, known_types=None): + super(CBackend, self).__init__() + # self.function_address_space = 'WITHIN_KERNEL ' + + def get_func_signature_pyb11(self, func): + sourcelines = getsourcelines(func)[0] + defn, lines = get_func_definition(sourcelines) + f_name, returns, args = self._analyze_method(func, lines) + pyb11_args = [] + pyb11_call = [] + c_args = [] + c_call = [] + for arg, value in args: + c_type = self.detect_type(arg, value) + c_args.append('{type} {arg}'.format(type=c_type, arg=arg)) + + c_call.append(arg) + pyb11_type = self.ctype_to_pyb11(c_type) + pyb11_args.append('{type} {arg}'.format(type=pyb11_type, arg=arg)) + if c_type.endswith('*'): + pyb11_call.append( + '({ctype}){arg}.request().ptr' + .format(arg=arg, ctype=c_type)) + else: + pyb11_call.append('{arg}'.format(arg=arg)) + + return (pyb11_args, pyb11_call), (c_args, c_call) + + def ctype_to_pyb11(self, c_type): + if c_type[-1] == '*': + return 'py::array_t<{}>'.format(c_type[:-1]) + else: + return c_type + + def _get_self_type(self): + return KnownType('GLOBAL_MEM %s*' % self._class_name) + +class CCompile(CBackend): + def __init__(self, func): + super(CCompile, self).__init__() + self.func = func + self.src = "not yet generated" + self.tp = Transpiler(backend='c') + self.c_func = self._compile() + + def _compile(self): + self.tp.add(self.func) + self.src = self.tp.get_code() + + py_data, c_data = self.get_func_signature_pyb11(self.func) + + pyb11_args = ', '.join(py_data[0][:]) + pyb11_call = ', '.join(py_data[1][:]) + hash_fn = get_md5(self.src) + modname = f'm_{hash_fn}' + template = Template(elwise_c_pybind) + src_bind = template.render( + name=self.func.__name__, + modname=modname, + pyb11_args=pyb11_args, + pyb11_call=pyb11_call + ) + self.src += src_bind + + mod = Cmodule(self.src, hash_fn, openmp=False, + extra_inc_dir=[pybind11.get_include()]) + module = mod.load() + return getattr(module, modname) + + def _massage_arg(self, x): + if isinstance(x, array.Array): + return x.dev + elif isinstance(x, np.ndarray): + return x + else: + return np.asarray(x) + + @profile + def __call__(self, *args, **kwargs): + c_args = [self._massage_arg(x) for x in args] + self.c_func(*c_args) + +elwise_c_template = ''' + +void ${name}(${arguments}){ + %if openmp: + #pragma omp parallel for + %endif + for(size_t i = 0; i < SIZE; i++){ + ${operations}; + } +} + +''' + +reduction_c_template = ''' +template +T combine(T a, T b){ + return ${red_expr}; +} + +template +T reduce_one_ar(int offset, int n, T initial_val, T* ary){ + T a, b, temp; + temp = initial_val; + + for (int i = offset; i < (n + offset); i++){ + a = temp; + b = ary[i]; + + temp = combine(a, b); + } + return temp; +} + +template +T reduce(int offset, int n, T initial_val${args_extra}){ + T a, b, temp; + temp = initial_val; + + for (int i = offset; i < (n + offset); i++){ + a = temp; + b = ${map_expr}; + + temp = combine(a, b); + } + return temp; +} + + +template +T reduce_all(long N, T initial_val${args_extra}){ + T ans = initial_val; + if (N > 0){ + %if openmp: + int ntiles = omp_get_max_threads(); + %else: + int ntiles = 1; + %endif + T* stage1_res = new T[ntiles]; + %if openmp: + #pragma omp parallel for + %endif + { + // Step 1 - reducing each tile + %if openmp: + int itile = omp_get_thread_num(); + %else: + int itile = 0; + %endif + int last_tile = ntiles - 1; + int tile_size = (N / ntiles); + int last_tile_sz = N - tile_size * last_tile; + int cur_tile_size = itile == ntiles - 1 ? last_tile_sz : tile_size; + int cur_start_idx = itile * tile_size; + + stage1_res[itile] = reduce(cur_start_idx, cur_tile_size, + initial_val${call_extra}); + %if openmp: + #pragma omp barrier + + #pragma omp single + %endif + ans = reduce_one_ar(0, ntiles, initial_val, stage1_res); + } + delete[] stage1_res; + } + return ans; +} +''' + +reduction_c_pybind = ''' + +PYBIND11_MODULE(${name}, m) { + m.def("${name}", [](long n${pyb_args}){ + return reduce_all(n, (${type})${neutral}${pyb_call}); + }); +} + +''' + +scan_c_template = ''' + +template +T combine(T a, T b){ + return ${scan_expr}; +} + + +template +T reduce( T* ary, int offset, int n, T initial_val${args_in_extra}){ + T a, b, temp; + temp = initial_val; + + for (int i = offset; i < (n + offset); i++){ + a = temp; + b = ${scan_input_expr_call}; + + temp = combine(a, b); + } + return temp; +} + +template +void excl_scan_wo_ip_exp( T* ary, T* out, int N, T initial_val){ + if (N > 0){ + T a, b, temp; + temp = initial_val; + + for (int i = 0; i < N; i++){ + a = temp; + b = ary[i]; + out[i] = temp; + temp = combine(a, b); + } + out[N] = temp; + } +} + + +template +void incl_scan( T* ary, int offset, int cur_buf_size, int N, + T initial_val, T last_item${args_extra}) +{ + if (N > 0){ + T a, b, carry, prev_item, item; + carry = initial_val; + + for (int i = offset; i < (cur_buf_size + offset); i++){ + a = carry; + b = ${scan_input_expr_call}; + prev_item = carry; + carry = combine(a, b); + item = carry; + + ${scan_output_expr_call}; + } + } +} + + +template +void scan( T* ary, long N, T initial_val${args_extra}){ + if (N > 0){ + %if openmp: + int ntiles = omp_get_max_threads(); + %else: + int ntiles = 1; + %endif + T* stage1_res = new T[ntiles]; + T* stage2_res = new T[ntiles + 1]; + %if openmp: + #pragma omp parallel + %endif + { + // Step 1 - reducing each tile + %if openmp: + int itile = omp_get_thread_num(); + %else: + int itile = 0; + %endif + int last_tile = ntiles - 1; + int tile_size = (N / ntiles); + int last_tile_sz = N - tile_size * last_tile; + int cur_tile_size = itile == ntiles - 1 ? last_tile_sz : tile_size; + int cur_start_idx = itile * tile_size; + + stage1_res[itile] = reduce(ary, cur_start_idx, cur_tile_size, + initial_val${call_in_extra}); + %if openmp: + #pragma omp barrier + + #pragma omp single + %endif + excl_scan_wo_ip_exp(stage1_res, stage2_res, + ntiles, initial_val); + + incl_scan(ary, cur_start_idx, cur_tile_size, N, + stage2_res[itile],stage2_res[ntiles]${call_extra}); + } + delete[] stage1_res; + delete[] stage2_res; + py::print(ary); + } +} +''' +scan_c_pybind = ''' + +PYBIND11_MODULE(${name}, m) { + m.def("${name}", [](py::array_t<${type}> x, long n${pyb_args}){ + return scan((${type}*) x.request().ptr, n, + (${type})${neutral}${pyb_call}); + }); +} +''' diff --git a/compyle/cimport.py b/compyle/cimport.py new file mode 100644 index 0000000..16b21f3 --- /dev/null +++ b/compyle/cimport.py @@ -0,0 +1,129 @@ +import os +import io +import importlib +import shutil +import sys +from filelock import FileLock + +from os.path import exists, expanduser, isdir, join + +import pybind11 +from distutils.extension import Extension +from distutils.command import build_ext +from distutils.core import setup +from distutils.errors import CompileError, LinkError + +from .ext_module import get_platform_dir, get_ext_extension, get_openmp_flags +from .capture_stream import CaptureMultipleStreams # noqa: 402 + + +class Cmodule: + def __init__(self, src, hash_fn, root=None, verbose=False, openmp=False, + extra_inc_dir=[pybind11.get_include()], + extra_link_args=[], extra_compile_args=[]): + self.src = src + self.hash = hash_fn + self.name = f'm_{self.hash}' + self.verbose = verbose + self.openmp = openmp + self.extra_inc_dir = extra_inc_dir + self.extra_link_args = extra_link_args + self.extra_compile_args = extra_compile_args + self._use_cpp11() + + self._setup_root(root) + self._setup_filenames() + self.lock = FileLock(self.lock_path, timeout=120) + + def _setup_root(self, root): + if root is None: + plat_dir = get_platform_dir() + self.root = expanduser(join('~', '.compyle', 'source', plat_dir)) + else: + self.root = root + + self.build_dir = join(self.root, 'build') + + if not isdir(self.build_dir): + try: + os.makedirs(self.build_dir) + except OSError: + pass + + def _write_source(self): + if not exists(self.src_path): + with io.open(self.src_path, 'w', encoding='utf-8') as f: + f.write(self.src) + + def _setup_filenames(self): + self.src_path = join(self.root, self.name + '.cpp') + self.ext_path = join(self.root, self.name + get_ext_extension()) + self.lock_path = join(self.root, self.name + '.lock') + + def is_build_needed(self): + return not exists(self.ext_path) + + def build(self): + self._include_openmp() + ext = Extension(name=self.name, + sources=[self.src_path], + language='c++', + include_dirs=self.extra_inc_dir, + extra_link_args=self.extra_link_args, + extra_compile_args=self.extra_compile_args) + args = [ + "build_ext", + "--build-lib=" + self.build_dir, + "--build-temp=" + self.build_dir, + "-v", + ] + + try: + with CaptureMultipleStreams() as stream: + setup(name=self.name, + ext_modules=[ext], + script_args=args, + cmdclass={"build_ext": build_ext.build_ext}) + shutil.move(join(self.build_dir, self.name + + get_ext_extension()), self.ext_path) + + except(CompileError, LinkError, SystemExit): + hline = "*"*80 + print(hline + "\nERROR") + s_out = stream.get_output() + print(s_out[0]) + print(s_out[1]) + msg = "Compilation of code failed, please check "\ + "error messages above." + print(hline + "\n" + msg) + sys.exit(1) + + def write_and_build(self): + """Write source and build the extension module""" + if self.is_build_needed(): + with self.lock: + self._write_source() + self.build() + else: + self._message("Precompiled code from:", self.src_path) + + def load(self): + self.write_and_build() + spec = importlib.util.spec_from_file_location(self.name, self.ext_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def _include_openmp(self): + if self.openmp: + ec, el = get_openmp_flags() + self.extra_compile_args += ec + self.extra_link_args += el + + def _use_cpp11(self): + self.extra_compile_args += ['-std=c++11'] + + def _message(self, *args): + msg = ' '.join(args) + if self.verbose: + print(msg) diff --git a/compyle/jit.py b/compyle/jit.py index 080fd42..bf0ee4f 100644 --- a/compyle/jit.py +++ b/compyle/jit.py @@ -5,7 +5,7 @@ import ast import importlib import warnings -import time +import json from pytools import memoize from .config import get_config from .cython_generator import CythonGenerator @@ -368,6 +368,8 @@ def __call__(self, *args, **kw): c_func(*c_args, **kw) event.record() event.synchronize() + elif self.backend == 'c': + c_func(*c_args) class ReductionJIT(parallel.ReductionBase): @@ -448,6 +450,10 @@ def __call__(self, *args, **kw): event.record() event.synchronize() return result.get() + elif self.backend == 'c': + size = len(c_args[0]) + c_args.insert(0, size) + return c_func(*c_args) class ScanJIT(parallel.ScanBase): @@ -567,3 +573,7 @@ def __call__(self, **kwargs): c_func(*[c_args_dict[k] for k in output_arg_keys]) event.record() event.synchronize() + elif self.backend == 'c': + size = len(c_args_dict[output_arg_keys[0]]) + c_args_dict['N'] = size + c_func(*[c_args_dict[k] for k in output_arg_keys]) diff --git a/compyle/parallel.py b/compyle/parallel.py index b124735..e66a9ef 100644 --- a/compyle/parallel.py +++ b/compyle/parallel.py @@ -6,17 +6,21 @@ """ +from compyle import c_backend from functools import wraps from textwrap import wrap from mako.template import Template import numpy as np +import pybind11 +from .cimport import Cmodule from .config import get_config from .profile import profile from .cython_generator import get_parallel_range, CythonGenerator from .transpiler import Transpiler, convert_to_float_if_needed -from .types import dtype_to_ctype +from .types import TYPES, annotate, dtype_to_ctype +from .ext_module import get_md5 from . import array @@ -503,6 +507,52 @@ def _generate(self, declarations=None): # FIXME: it is difficult to get the sources from pycuda. self.all_source = self.source return knl + elif self.backend == 'c': + self.pyb11_backend = c_backend.CBackend() + py_data, c_data = self.pyb11_backend.get_func_signature_pyb11( + self.func) + pyb11_args = ', '.join(py_data[0][1:]) + size = '{arg}.request().size'.format(arg=c_data[1][1]) + pyb11_call = ', '.join([size] + py_data[1][1:]) + c_defn = ['size_t SIZE'] + c_data[0][1:] + arguments = ', '.join(c_defn) + name = self.func.__name__ + expr = '{func}({args})'.format( + func=name, + args=', '.join(c_data[1]) + ) + + openmp = self._config.use_openmp + + templete_elwise = Template(c_backend.elwise_c_template) + src_elwise = templete_elwise.render( + name=self.name, + arguments=arguments, + openmp=openmp, + operations=expr + ) + + self.source = self.tp.get_code() + if openmp: + self.source = '#include \n' + self.source + self.all_source = self.source + '\n' + src_elwise + hash_fn = get_md5(self.all_source) + modname = f'm_{hash_fn}' + + template = Template(c_backend.elwise_c_pybind) + src_bind = template.render( + name=self.name, + modname=modname, + pyb11_args=pyb11_args, + pyb11_call=pyb11_call + ) + + self.all_source += src_bind + + mod = Cmodule(self.all_source, hash_fn, openmp=openmp, + extra_inc_dir=[pybind11.get_include()]) + module = mod.load() + return getattr(module, modname) def _correct_opencl_address_space(self, c_data): code = self.tp.blocks[-1].code.splitlines() @@ -552,6 +602,8 @@ def __call__(self, *args, **kw): self.c_func(*c_args, **kw) event.record() event.synchronize() + elif self.backend == 'c': + self.c_func(*c_args) class Elementwise(object): @@ -662,6 +714,71 @@ def _generate(self, declarations=None): self.tp.compile() self.all_source = self.tp.source return getattr(self.tp.mod, 'py_' + self.name) + elif self.backend == 'c': + self.pyb11_backend = c_backend.CBackend() + if self.func is not None: + self.func.__annotations__['return'] = TYPES[self.type] + self.tp.add(self.func, declarations=declarations) + pyb_data, c_data = self.pyb11_backend.get_func_signature_pyb11( + self.func) + c_call = c_data[1] + + c_call_default = ['N', 'neutral'] + predefined_vars = ['i'] + c_call_default + c_args_extra = [[], []] + pyb_args_extra = [[], []] + for i, var in enumerate(c_call[1:]): + if var not in predefined_vars: + c_args_extra[0].append(c_data[0][i + 1]) + c_args_extra[1].append(var) + pyb_args_extra[0].append(pyb_data[0][i + 1]) + pyb_args_extra[1].append(pyb_data[1][i + 1]) + c_args_extra_str = ", " + ', '.join(c_args_extra[0]) + c_call_extra_str = ", " + ', '.join(c_args_extra[1]) + pyb_args_extra_str = ", " + ', '.join(pyb_args_extra[0]) + pyb_call_extra_str = ", " + ', '.join(pyb_args_extra[1]) + map_expr = f"{self.func.__name__}({', '.join(c_call)})" + else: + c_args_extra_str = f", {self.type + '*'} in" + c_call_extra_str = ", in" + arg_typ = self.pyb11_backend.ctype_to_pyb11(self.type + '*') + pyb_args_extra_str = f", {arg_typ} in" + pyb_call_extra_str = f", ({self.type}*) in.request().ptr" + map_expr = "in[i]" + self.source = self.tp.get_code() + openmp = self._config.use_openmp + if openmp: + self.source = '#include \n' + self.source + + template_red = Template(c_backend.reduction_c_template) + src_red = template_red.render( + args_extra=c_args_extra_str, + call_extra=c_call_extra_str, + map_expr=map_expr, + red_expr=self.reduce_expr, + name=self.name, + type=self.type, + openmp=openmp + ) + self.all_source = self.source + src_red + hash_fn = get_md5(self.all_source) + modname = f'm_{hash_fn}' + + template_pybind = Template(c_backend.reduction_c_pybind) + src_pybind = template_pybind.render( + name=modname, + type=self.type, + pyb_args=pyb_args_extra_str, + pyb_call=pyb_call_extra_str, + neutral=self.neutral, + ) + self.all_source += src_pybind + + mod = Cmodule(self.all_source, hash_fn, openmp=openmp, + extra_inc_dir=[pybind11.get_include()]) + module = mod.load() + return getattr(module, modname) + elif self.backend == 'opencl': if self.func is not None: self.tp.add(self.func, declarations=declarations) @@ -815,6 +932,10 @@ def __call__(self, *args): event.record() event.synchronize() return result.get() + elif self.backend == 'c': + size = len(c_args[0]) + c_args.insert(0, size) + return self.c_func(*c_args) class Reduction(object): @@ -939,6 +1060,8 @@ def _generate(self, declarations=None): return self._generate_cuda_kernel(declarations=declarations) elif self.backend == 'cython': return self._generate_cython_code(declarations=declarations) + elif self.backend == 'c': + return self._generate_c_code(declarations=declarations) def _default_cython_input_function(self): py_data = (['int i', '{type}[:] input'.format(type=self.type)], @@ -1052,6 +1175,114 @@ def _generate_cython_code(self, declarations=None): self.all_source = self.tp.source return getattr(self.tp.mod, 'py_' + self.name) + def _generate_c_code(self, declarations=None): + self.pyb11_backend = c_backend.CBackend() + if not self.input_func: + @annotate(i='int', ary=f'{self.type}p', return_=f'{self.type}') + def input_expr(i, ary): + return ary[i] + self.input_func = input_expr + + self.tp.add(self.input_func, declarations=declarations) + pyb_data_in, c_data_in = self.pyb11_backend.get_func_signature_pyb11( + self.input_func) + self.tp.add(self.output_func, declarations=declarations) + self.source = self.tp.get_code() + openmp = self._config.use_openmp + if openmp: + self.source = '#include \n' + self.source + c_call_in = c_data_in[1] + pyb_data_out, c_data_out = self.pyb11_backend.get_func_signature_pyb11( + self.output_func) + c_call_out = c_data_out[1] + c_call_default = ['ary', 'N'] + c_internal_var = ['item', 'prev_item', 'last_item'] + predefined_vars = c_call_default + c_internal_var + + c_args_in_extra = [[], []] + pyb_args_in_extra = [[], []] + for i, var in enumerate(c_call_in[1:]): + if var not in predefined_vars: + c_args_in_extra[0].append(c_data_in[0][i + 1]) + c_args_in_extra[1].append(var) + pyb_args_in_extra[0].append(pyb_data_in[0][i + 1]) + pyb_args_in_extra[1].append(pyb_data_in[1][i + 1]) + c_args_out_extra = [[], []] + pyb_args_extra = [[], []] + for i, var in enumerate(c_call_out[1:]): + if var not in predefined_vars: + c_args_out_extra[0].append(c_data_out[0][i + 1]) + c_args_out_extra[1].append(var) + pyb_args_extra[0].append(pyb_data_out[0][i + 1]) + pyb_args_extra[1].append(pyb_data_out[1][i + 1]) + + c0 = c_args_in_extra[0] + c1 = c_args_in_extra[1] + c_args_in_extra_str = f", {','.join(c0)}" if c1 else "" + c_call_in_extra_str = f", {','.join(c1)}" if c1 else "" + + c_args_extra = c_args_out_extra.copy() + for i, var in enumerate(c_args_in_extra[1]): + if var not in c_args_extra[1]: + c_args_extra[0].append(c_args_in_extra[0][i]) + c_args_extra[1].append(var) + pyb_args_extra[0].append(pyb_args_in_extra[0][i]) + pyb_args_extra[1].append(pyb_args_in_extra[1][i]) + + if not hasattr(self.output_func, 'arg_keys'): + self.output_func.arg_keys = {} + self.output_func.arg_keys[self._get_backend_key( + )] = c_call_default + c_args_extra[1] + + c0 = c_args_extra[0] + c1 = c_args_extra[1] + p0 = pyb_args_extra[0] + p1 = pyb_args_extra[1] + c_args_extra_str = f", {', '.join(c0)}" if c1 else "" + c_call_extra_str = f", {', '.join(c1)}" if c1 else "" + pyb_args_extra_str = f", {', '.join(p0)}" if p1 else "" + pyb_call_extra_str = f", {', '.join(p1)}" if p1 else "" + + ip_fname = self.input_func.__name__ + op_fname = self.output_func.__name__ + c_call_in_str = f"{ip_fname}({', '.join(c_call_in)})" + c_call_out_str = f"{op_fname}({', '.join(c_call_out)})" + + template_scan = Template(c_backend.scan_c_template) + src_scan = template_scan.render( + scan_expr=self.scan_expr, + scan_input_expr_call=c_call_in_str, + scan_output_expr_call=c_call_out_str, + args_extra=c_args_extra_str, + args_in_extra=c_args_in_extra_str, + call_extra=c_call_extra_str, + call_in_extra=c_call_in_extra_str, + name=self.name, + type=self.type, + pyb_args=pyb_args_extra_str, + pyb_call=pyb_call_extra_str, + openmp=openmp + ) + self.all_source = self.source + src_scan + hash_fn = get_md5(self.all_source) + modname = f'm_{hash_fn}' + + pybind_template = Template(c_backend.scan_c_pybind) + src_pybind = pybind_template.render( + name=modname, + type=self.type, + pyb_args=pyb_args_extra_str, + pyb_call=pyb_call_extra_str, + neutral=self.neutral + ) + + self.all_source += src_pybind + + mod = Cmodule(self.all_source, hash_fn, openmp=openmp, + extra_inc_dir=[pybind11.get_include()]) + module = mod.load() + return getattr(module, modname) + def _wrap_ocl_function(self, func, func_type=None, declarations=None): if func is not None: self.tp.add(func, declarations=declarations) @@ -1234,6 +1465,10 @@ def __call__(self, **kwargs): self.c_func(*[c_args_dict[k] for k in output_arg_keys]) event.record() event.synchronize() + elif self.backend == 'c': + size = len(c_args_dict[output_arg_keys[0]]) + c_args_dict['N'] = size + self.c_func(*[c_args_dict[k] for k in output_arg_keys]) class Scan(object): diff --git a/compyle/tests/test_c_backend.py b/compyle/tests/test_c_backend.py new file mode 100644 index 0000000..8b1c82d --- /dev/null +++ b/compyle/tests/test_c_backend.py @@ -0,0 +1,43 @@ +import unittest +from unittest import TestCase +from ..c_backend import CBackend, CCompile +from ..types import annotate +import numpy as np + + +class TestCBackend(TestCase): + def test_get_func_signature(self): + cbackend = CBackend() + + @annotate(x='int', y='intp', z='int', w='double') + def test_fn(x, y, z=2, w=3.0): + return x+y+z+w + temp = cbackend.get_func_signature(test_fn) + (pyb11_args, pyb11_call), (c_args, c_call) = temp + exp_pyb11_args = ['int x', 'int[:] y', 'int z', 'double w'] + exp_pyb11_call = ['x', '&y[0]', 'z', 'w'] + exp_c_args = ['int x', 'int* y', 'int z', 'double w'] + exp_c_call = ['x', 'y', 'z', 'w'] + + self.assertListEqual(pyb11_args, exp_pyb11_args) + self.assertListEqual(pyb11_call, exp_pyb11_call) + self.assertListEqual(c_args, exp_c_args) + self.assertListEqual(c_call, exp_c_call) + +class TestCCompile(TestCase): + def test_compile(self): + @annotate(int='n, p', intp='x, y') + def get_pow(n, p, x, y): + for i in range(n): + y[i] = x[i]**p + c_get_pow = CCompile(get_pow) + n = 5 + p = 5 + x = np.ones(n, dtype=np.int32) * 2 + y = np.zeros(n, dtype=np.int32) + y_exp = np.ones(n, dtype=np.int32) * 32 + c_get_pow(n, p, x, y) + assert(np.all(y == y_exp)) + +if __name__ == '__main__': + unittest.main() diff --git a/compyle/tests/test_cimport.py b/compyle/tests/test_cimport.py new file mode 100644 index 0000000..9d0d5b9 --- /dev/null +++ b/compyle/tests/test_cimport.py @@ -0,0 +1,65 @@ +from genericpath import exists +from ntpath import join +import tempfile +import unittest +from unittest import TestCase +import numpy as np +from os.path import exists, expanduser, isdir, join +import sys +import os +from mako.template import Template + + +from compyle.cimport import Cmodule +from compyle.types import annotate +from compyle.ext_module import get_platform_dir, get_md5, get_ext_extension + +dummy_module = ''' +#include +#include +namespace py = pybind11; + +void f(int n, int* x, int* y) +{ + for(int i = 0; i < n; i++){ + y[i] = (2 * x[i]); + } +} +''' +pybind = """ + +PYBIND11_MODULE(${name}, m) { + + m.def("${name}", [](py::array_t x, py::array_t y){ + return f(x.request().size, (int*)x.request().ptr, + (int*)y.request().ptr); + }); +} +""" + + +class TestCmodule(TestCase): + def setUp(self): + self.root = tempfile.mkdtemp() + + def test_build(self): + hash_fn = get_md5(dummy_module) + name = f'm_{hash_fn}' + pyb_template = Template(pybind) + src_pybind = pyb_template.render(name=name) + + all_src = dummy_module + src_pybind + mod = Cmodule(all_src, hash_fn=hash_fn, root=self.root) + checksum = get_md5(dummy_module) + self.assertTrue(mod.is_build_needed()) + + mod.load() + self.assertTrue(exists(join(self.root, 'build'))) + self.assertTrue(exists(join(self.root, 'm_' + checksum + '.cpp'))) + self.assertTrue( + exists(join(self.root, f'{name}' + get_ext_extension()))) + self.assertFalse(mod.is_build_needed()) + + +if __name__ == '__main__': + unittest.main() diff --git a/compyle/tests/test_parallel.py b/compyle/tests/test_parallel.py index 8fed025..ce1667a 100644 --- a/compyle/tests/test_parallel.py +++ b/compyle/tests/test_parallel.py @@ -19,6 +19,65 @@ def external(x): return x +class ParallelUtilsBaseC(object): + def test_elementwise_works_with_c(self): + self._check_simple_elementwise(backend='c') + + def test_elementwise_works_with_global_constant_c(self): + self._check_elementwise_with_constant(backend='c') + + def test_reduction_works_without_map_c(self): + self._check_simple_reduction(backend='c') + + def test_reduction_works_with_map_c(self): + self._check_reduction_with_map(backend='c') + + def test_reduction_works_with_external_func_c(self): + self._check_reduction_with_external_func(backend='c') + + def test_reduction_works_neutral_c(self): + self._check_reduction_min(backend='c') + + def test_scan_works_c(self): + self._test_scan(backend='c') + + def test_scan_works_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan(backend='c') + + def test_large_scan_works_c_parallel(self): + with use_config(use_openmp=True): + self._test_large_scan(backend='c') + + def test_scan_works_with_external_func_c(self): + self._test_scan_with_external_func(backend='c') + + def test_scan_works_with_external_func_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan_with_external_func(backend='c') + + def test_scan_last_item_c_parallel(self): + with use_config(use_openmp=True): + self._test_scan_last_item(backend='c') + + def test_scan_last_item_c_serial(self): + self._test_scan_last_item(backend='c') + + def test_unique_scan_c(self): + self._test_unique_scan(backend='c') + + def test_unique_scan_c_parallel(self): + with use_config(use_openmp=True): + self._test_unique_scan(backend='c') + + def test_repeated_scans_with_different_settings_c(self): + with use_config(use_openmp=False): + self._test_unique_scan(backend='c') + + with use_config(use_openmp=True): + self._test_unique_scan(backend='c') + + class ParallelUtilsBase(object): def test_elementwise_works_with_cython(self): self._check_simple_elementwise(backend='cython') @@ -221,7 +280,9 @@ def test_repeated_scans_with_different_settings(self): self._test_unique_scan(backend='cython') -class TestParallelUtils(ParallelUtilsBase, unittest.TestCase): +class TestParallelUtils(ParallelUtilsBase, + ParallelUtilsBaseC, + unittest.TestCase): def setUp(self): cfg = get_config() self._use_double = cfg.use_double diff --git a/compyle/transpiler.py b/compyle/transpiler.py index 46663e6..bc59619 100644 --- a/compyle/transpiler.py +++ b/compyle/transpiler.py @@ -8,7 +8,7 @@ from .config import get_config from .ast_utils import get_unknown_names_and_calls from .cython_generator import CythonGenerator, CodeGenerationError -from .translator import OpenCLConverter, CUDAConverter +from .translator import OpenCLConverter, CUDAConverter, CConverter from .ext_module import ExtModule from .extern import Extern, get_extern_code from .utils import getsourcelines @@ -187,6 +187,15 @@ def __init__(self, backend='cython', incl_cluda=True): #define max(x, y) fmax((double)(x), (double)(y)) ''') + elif backend == 'c': + self._cgen = CConverter() + self.header = dedent(''' + // c code for with PyBind11 binding + #include + #include + namespace py = pybind11; + using namespace std; + ''') def _handle_symbol(self, name, value): backend = self.backend @@ -216,6 +225,8 @@ def _handle_symbol(self, name, value): return '#define {name} {value}'.format( name=name, value=value ) + elif self.backend == 'c': + return f"{ctype} {name} = {value};" def _get_comment(self): return '#' if self.backend == 'cython' else '//' @@ -278,6 +289,10 @@ def add(self, obj, declarations=None): code = self._cgen.parse( obj, declarations=declarations.get(obj.__name__) if declarations else None) + elif self.backend == 'c': + code = self._cgen.parse( + obj, declarations=declarations.get(obj.__name__) + if declarations else None) cb = CodeBlock(obj, code) self.blocks.append(cb) diff --git a/requirements.txt b/requirements.txt index db7fea3..ae429d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ pytools cython numpy pytest +pybind11 +filelock \ No newline at end of file