Skip to content

Add flop counter to elementwise for opencl/cuda #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions compyle/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self):
self._use_double = None
self._omp_schedule = None
self._profile = None
self._count_flops = None
self._use_local_memory = None
self._wgs = None
self._suppress_warnings = None
Expand Down Expand Up @@ -129,6 +130,19 @@ def profile(self, value):
def _profile_default(self):
return False

@property
def count_flops(self):
    """Whether flop counting is enabled.

    The stored value starts out as ``None``; on first access it is
    replaced by whatever ``_count_flops_default()`` returns and that
    value is cached on the instance.
    """
    current = self._count_flops
    if current is None:
        # Resolve lazily and remember the result for later reads.
        current = self._count_flops = self._count_flops_default()
    return current

@count_flops.setter
def count_flops(self, value):
    """Store ``value`` as-is; assigning ``None`` re-arms the default."""
    self._count_flops = value

def _count_flops_default(self):
    """Return the value used when ``count_flops`` was never set."""
    return False

@property
def use_local_memory(self):
if self._use_local_memory is None:
Expand Down
26 changes: 15 additions & 11 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
dtype_to_knowntype, annotate)
from .extern import Extern
from .utils import getsourcelines
from .profile import profile
from .profile import record_flops, profile

from . import array
from . import parallel
Expand Down Expand Up @@ -265,13 +265,9 @@ def visit_UnaryOp(self, node):
return self.visit(node.operand)

def visit_Return(self, node):
if isinstance(node.value, ast.Name) or \
isinstance(node.value, ast.Subscript) or \
isinstance(node.value, ast.Num) or \
isinstance(node.value, ast.BinOp) or \
isinstance(node.value, ast.Call) or \
isinstance(node.value, ast.IfExp) or \
isinstance(node.value, ast.UnaryOp):
valid_return_expr = (ast.Name, ast.Subscript, ast.Num, ast.BinOp,
ast.Call, ast.IfExp, ast.UnaryOp)
if isinstance(node.value, valid_return_expr):
result_type = self.visit(node.value)
if result_type:
self.arg_types['return_'] = result_type
Expand All @@ -287,11 +283,12 @@ def visit_Return(self, node):
class ElementwiseJIT(parallel.ElementwiseBase):
def __init__(self, func, backend=None):
backend = array.get_backend(backend)
self.tp = Transpiler(backend=backend)
self._config = get_config()
self.tp = Transpiler(backend=backend,
count_flops=self._config.count_flops)
self.backend = backend
self.name = 'elwise_%s' % func.__name__
self.func = func
self._config = get_config()
self.cython_gen = CythonGenerator()
self.source = '# Code jitted, call the function to generate the code.'
self.all_source = self.source
Expand Down Expand Up @@ -333,6 +330,10 @@ def _massage_arg(self, x):
def __call__(self, *args, **kw):
c_func = self._generate_kernel(*args)
c_args = [self._massage_arg(x) for x in args]
if self._config.count_flops:
flop_counter = array.zeros(args[0].length, np.int64,
backend=self.backend)
c_args.append(flop_counter.dev)

if self.backend == 'cython':
size = len(c_args[0])
Expand All @@ -347,6 +348,9 @@ def __call__(self, *args, **kw):
c_func(*c_args, **kw)
event.record()
event.synchronize()
if self._config.count_flops:
flops = array.sum(flop_counter)
record_flops(self.name, flops)


class ReductionJIT(parallel.ReductionBase):
Expand Down Expand Up @@ -523,7 +527,7 @@ def __call__(self, **kwargs):
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
if self._get_backend_key() in self.output_func.arg_keys:
output_arg_keys = self.output_func.arg_keys[
self._get_backend_key()]
self._get_backend_key()]
else:
raise ValueError("No kernel arguments found for backend = %s, "
"use_openmp = %s, use_double = %s" %
Expand Down
40 changes: 35 additions & 5 deletions compyle/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np

from .config import get_config
from .profile import profile
from .profile import record_flops, profile
from .cython_generator import get_parallel_range, CythonGenerator
from .transpiler import Transpiler, convert_to_float_if_needed
from .types import dtype_to_ctype
Expand Down Expand Up @@ -404,11 +404,12 @@ def get_common_cache_key(obj):
class ElementwiseBase(object):
def __init__(self, func, backend=None):
backend = array.get_backend(backend)
self.tp = Transpiler(backend=backend)
self._config = get_config()
self.tp = Transpiler(backend=backend,
count_flops=self._config.count_flops)
self.backend = backend
self.name = 'elwise_%s' % func.__name__
self.func = func
self._config = get_config()
self.cython_gen = CythonGenerator()
self.queue = None
# This is the source generated for the user code.
Expand Down Expand Up @@ -453,11 +454,17 @@ def _generate(self, declarations=None):
ctx = get_context()
self.queue = get_queue()
name = self.func.__name__
call_args = ', '.join(c_data[1])
if self._config.count_flops:
call_args += ', cpy_flop_counter'
expr = '{func}({args})'.format(
func=name,
args=', '.join(c_data[1])
args=call_args
)
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
if self._config.count_flops:
arguments += ', long* cpy_flop_counter'

preamble = convert_to_float_if_needed(self.tp.get_code())
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
double_support=True
Expand All @@ -483,11 +490,17 @@ def _generate(self, declarations=None):
from pycuda.elementwise import ElementwiseKernel
from pycuda._cluda import CLUDA_PREAMBLE
name = self.func.__name__
call_args = ', '.join(c_data[1])
if self._config.count_flops:
call_args += ', cpy_flop_counter'
expr = '{func}({args})'.format(
func=name,
args=', '.join(c_data[1])
args=call_args
)
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
if self._config.count_flops:
arguments += ', long* cpy_flop_counter'

preamble = convert_to_float_if_needed(self.tp.get_code())
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
double_support=True
Expand Down Expand Up @@ -519,6 +532,8 @@ def _add_address_space(arg):
return arg

args = [_add_address_space(arg) for arg in c_data[0]]
if self._config.count_flops:
args.append('GLOBAL_MEM long* cpy_flop_counter')
code[:header_idx] = wrap(
'WITHIN_KERNEL void {func}({args})'.format(
func=self.func.__name__,
Expand All @@ -527,6 +542,14 @@ def _add_address_space(arg):
width=78, subsequent_indent=' ' * 4, break_long_words=False
)
self.tp.blocks[-1].code = '\n'.join(code)
if self._config.count_flops:
for idx, block in enumerate(self.tp.blocks[:-1]):
self.tp.blocks[idx].code = block.code.replace(
'${offset}', '0'
)
self.tp.blocks[-1].code = self.tp.blocks[-1].code.replace(
'${offset}', 'i'
)

def _massage_arg(self, x):
if isinstance(x, array.Array):
Expand All @@ -539,6 +562,10 @@ def _massage_arg(self, x):
@profile
def __call__(self, *args, **kw):
c_args = [self._massage_arg(x) for x in args]
if self._config.count_flops:
flop_counter = array.zeros(args[0].length, np.int64,
backend=self.backend)
c_args.append(flop_counter.dev)
if self.backend == 'cython':
size = len(c_args[0])
c_args.insert(0, size)
Expand All @@ -552,6 +579,9 @@ def __call__(self, *args, **kw):
self.c_func(*c_args, **kw)
event.record()
event.synchronize()
if self._config.count_flops:
flops = array.sum(flop_counter)
record_flops(self.name, flops)


class Elementwise(object):
Expand Down
41 changes: 41 additions & 0 deletions compyle/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})


def _record_profile(name, time):
Expand All @@ -16,6 +17,12 @@ def _record_profile(name, time):
_profile_info[name]['calls'] += 1


def record_flops(name, flops):
    """Accumulate a flop count for the kernel called ``name``.

    Adds ``flops`` to the running total kept under ``name`` in the
    module-level ``_flops_info`` mapping and bumps that entry's call
    counter by one.
    """
    # The mapping is only mutated, never rebound, so no ``global`` is needed.
    entry = _flops_info[name]
    entry['flops'] += flops
    entry['calls'] += 1


@contextmanager
def profile_ctx(name):
""" Context manager for profiling
Expand Down Expand Up @@ -54,6 +61,21 @@ def get_profile_info():
return _profile_info


def get_flops_info():
    """Return the module-level flop-count mapping.

    The object currently bound to ``_flops_info`` is handed back directly
    (no copy is made).
    """
    return _flops_info


def reset_profile_info():
    """Discard all recorded profiling data.

    Rebinds the module-level ``_profile_info`` to a brand-new
    ``defaultdict``; a mapping object obtained before the reset is left
    untouched and is no longer the one this module updates.
    """
    global _profile_info

    def _blank_entry():
        return {'calls': 0, 'time': 0}

    _profile_info = defaultdict(_blank_entry)


def reset_flops_info():
    """Discard all recorded flop counts.

    Rebinds the module-level ``_flops_info`` to a brand-new
    ``defaultdict``; a mapping object obtained before the reset is left
    untouched and is no longer the one this module updates.
    """
    global _flops_info

    def _blank_entry():
        return {'calls': 0, 'flops': 0}

    _flops_info = defaultdict(_blank_entry)


def print_profile():
global _profile_info
profile_data = sorted(_profile_info.items(), key=lambda x: x[1]['time'],
Expand All @@ -73,6 +95,25 @@ def print_profile():
print("Total profiled time: %g secs" % tot_time)


def print_flops_info():
    """Print a table of the recorded flop counts, largest first.

    One row is printed per entry of the module-level ``_flops_info``
    mapping (name, number of calls, accumulated flops), followed by the
    grand total.  When nothing has been recorded a single notice is
    printed instead.
    """
    if not _flops_info:
        print("No flops information available")
        return

    row_fmt = "{:<40} {:<10} {:<10}"
    ranked = sorted(_flops_info.items(), key=lambda item: item[1]['flops'],
                    reverse=True)

    print("FLOPS info:")
    print(row_fmt.format('Function', 'N calls', 'FLOPS'))
    tot_flops = 0
    for kernel, data in ranked:
        print(row_fmt.format(kernel, data['calls'], data['flops']))
        tot_flops += data['flops']
    print("Total FLOPS: %i" % tot_flops)


def profile_kernel(kernel, name, backend=None):
"""For profiling raw PyCUDA/PyOpenCL kernels or cython functions
"""
Expand Down
Loading