Skip to content

Commit 129d687

Browse files
committed
Add flop counter to elementwise for opencl/cuda
1 parent bc3f1a6 commit 129d687

File tree

9 files changed

+541
-55
lines changed

9 files changed

+541
-55
lines changed

compyle/config.py

+14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self):
1515
self._use_double = None
1616
self._omp_schedule = None
1717
self._profile = None
18+
self._count_flops = None
1819
self._use_local_memory = None
1920
self._wgs = None
2021
self._suppress_warnings = None
@@ -129,6 +130,19 @@ def profile(self, value):
129130
def _profile_default(self):
130131
return False
131132

133+
@property
134+
def count_flops(self):
135+
if self._count_flops is None:
136+
self._count_flops = self._count_flops_default()
137+
return self._count_flops
138+
139+
@count_flops.setter
140+
def count_flops(self, value):
141+
self._count_flops = value
142+
143+
def _count_flops_default(self):
144+
return False
145+
132146
@property
133147
def use_local_memory(self):
134148
if self._use_local_memory is None:

compyle/jit.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
dtype_to_knowntype, annotate)
1515
from .extern import Extern
1616
from .utils import getsourcelines
17-
from .profile import profile
17+
from .profile import record_flops, profile
1818

1919
from . import array
2020
from . import parallel
@@ -265,13 +265,9 @@ def visit_UnaryOp(self, node):
265265
return self.visit(node.operand)
266266

267267
def visit_Return(self, node):
268-
if isinstance(node.value, ast.Name) or \
269-
isinstance(node.value, ast.Subscript) or \
270-
isinstance(node.value, ast.Num) or \
271-
isinstance(node.value, ast.BinOp) or \
272-
isinstance(node.value, ast.Call) or \
273-
isinstance(node.value, ast.IfExp) or \
274-
isinstance(node.value, ast.UnaryOp):
268+
valid_return_expr = (ast.Name, ast.Subscript, ast.Num, ast.BinOp,
269+
ast.Call, ast.IfExp, ast.UnaryOp)
270+
if isinstance(node.value, valid_return_expr):
275271
result_type = self.visit(node.value)
276272
if result_type:
277273
self.arg_types['return_'] = result_type
@@ -287,11 +283,12 @@ def visit_Return(self, node):
287283
class ElementwiseJIT(parallel.ElementwiseBase):
288284
def __init__(self, func, backend=None):
289285
backend = array.get_backend(backend)
290-
self.tp = Transpiler(backend=backend)
286+
self._config = get_config()
287+
self.tp = Transpiler(backend=backend,
288+
count_flops=self._config.count_flops)
291289
self.backend = backend
292290
self.name = 'elwise_%s' % func.__name__
293291
self.func = func
294-
self._config = get_config()
295292
self.cython_gen = CythonGenerator()
296293
self.source = '# Code jitted, call the function to generate the code.'
297294
self.all_source = self.source
@@ -333,6 +330,10 @@ def _massage_arg(self, x):
333330
def __call__(self, *args, **kw):
334331
c_func = self._generate_kernel(*args)
335332
c_args = [self._massage_arg(x) for x in args]
333+
if self._config.count_flops:
334+
flop_counter = array.zeros(args[0].length, np.int64,
335+
backend=self.backend)
336+
c_args.append(flop_counter.dev)
336337

337338
if self.backend == 'cython':
338339
size = len(c_args[0])
@@ -347,6 +348,9 @@ def __call__(self, *args, **kw):
347348
c_func(*c_args, **kw)
348349
event.record()
349350
event.synchronize()
351+
if self._config.count_flops:
352+
flops = array.sum(flop_counter)
353+
record_flops(self.name, flops)
350354

351355

352356
class ReductionJIT(parallel.ReductionBase):
@@ -523,7 +527,7 @@ def __call__(self, **kwargs):
523527
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
524528
if self._get_backend_key() in self.output_func.arg_keys:
525529
output_arg_keys = self.output_func.arg_keys[
526-
self._get_backend_key()]
530+
self._get_backend_key()]
527531
else:
528532
raise ValueError("No kernel arguments found for backend = %s, "
529533
"use_openmp = %s, use_double = %s" %

compyle/parallel.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414

1515
from .config import get_config
16-
from .profile import profile
16+
from .profile import record_flops, profile
1717
from .cython_generator import get_parallel_range, CythonGenerator
1818
from .transpiler import Transpiler, convert_to_float_if_needed
1919
from .types import dtype_to_ctype
@@ -404,11 +404,12 @@ def get_common_cache_key(obj):
404404
class ElementwiseBase(object):
405405
def __init__(self, func, backend=None):
406406
backend = array.get_backend(backend)
407-
self.tp = Transpiler(backend=backend)
407+
self._config = get_config()
408+
self.tp = Transpiler(backend=backend,
409+
count_flops=self._config.count_flops)
408410
self.backend = backend
409411
self.name = 'elwise_%s' % func.__name__
410412
self.func = func
411-
self._config = get_config()
412413
self.cython_gen = CythonGenerator()
413414
self.queue = None
414415
# This is the source generated for the user code.
@@ -453,11 +454,17 @@ def _generate(self, declarations=None):
453454
ctx = get_context()
454455
self.queue = get_queue()
455456
name = self.func.__name__
457+
call_args = ', '.join(c_data[1])
458+
if self._config.count_flops:
459+
call_args += ', cpy_flop_counter'
456460
expr = '{func}({args})'.format(
457461
func=name,
458-
args=', '.join(c_data[1])
462+
args=call_args
459463
)
460464
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
465+
if self._config.count_flops:
466+
arguments += ', long* cpy_flop_counter'
467+
461468
preamble = convert_to_float_if_needed(self.tp.get_code())
462469
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
463470
double_support=True
@@ -483,11 +490,17 @@ def _generate(self, declarations=None):
483490
from pycuda.elementwise import ElementwiseKernel
484491
from pycuda._cluda import CLUDA_PREAMBLE
485492
name = self.func.__name__
493+
call_args = ', '.join(c_data[1])
494+
if self._config.count_flops:
495+
call_args += ', cpy_flop_counter'
486496
expr = '{func}({args})'.format(
487497
func=name,
488-
args=', '.join(c_data[1])
498+
args=call_args
489499
)
490500
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
501+
if self._config.count_flops:
502+
arguments += ', long* cpy_flop_counter'
503+
491504
preamble = convert_to_float_if_needed(self.tp.get_code())
492505
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
493506
double_support=True
@@ -519,6 +532,8 @@ def _add_address_space(arg):
519532
return arg
520533

521534
args = [_add_address_space(arg) for arg in c_data[0]]
535+
if self._config.count_flops:
536+
args.append('GLOBAL_MEM long* cpy_flop_counter')
522537
code[:header_idx] = wrap(
523538
'WITHIN_KERNEL void {func}({args})'.format(
524539
func=self.func.__name__,
@@ -527,6 +542,14 @@ def _add_address_space(arg):
527542
width=78, subsequent_indent=' ' * 4, break_long_words=False
528543
)
529544
self.tp.blocks[-1].code = '\n'.join(code)
545+
if self._config.count_flops:
546+
for idx, block in enumerate(self.tp.blocks[:-1]):
547+
self.tp.blocks[idx].code = block.code.replace(
548+
'${offset}', '0'
549+
)
550+
self.tp.blocks[-1].code = self.tp.blocks[-1].code.replace(
551+
'${offset}', 'i'
552+
)
530553

531554
def _massage_arg(self, x):
532555
if isinstance(x, array.Array):
@@ -539,6 +562,10 @@ def _massage_arg(self, x):
539562
@profile
540563
def __call__(self, *args, **kw):
541564
c_args = [self._massage_arg(x) for x in args]
565+
if self._config.count_flops:
566+
flop_counter = array.zeros(args[0].length, np.int64,
567+
backend=self.backend)
568+
c_args.append(flop_counter.dev)
542569
if self.backend == 'cython':
543570
size = len(c_args[0])
544571
c_args.insert(0, size)
@@ -552,6 +579,9 @@ def __call__(self, *args, **kw):
552579
self.c_func(*c_args, **kw)
553580
event.record()
554581
event.synchronize()
582+
if self._config.count_flops:
583+
flops = array.sum(flop_counter)
584+
record_flops(self.name, flops)
555585

556586

557587
class Elementwise(object):

compyle/profile.py

+41
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})
11+
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})
1112

1213

1314
def _record_profile(name, time):
@@ -16,6 +17,12 @@ def _record_profile(name, time):
1617
_profile_info[name]['calls'] += 1
1718

1819

20+
def record_flops(name, flops):
21+
global _flops_info
22+
_flops_info[name]['flops'] += flops
23+
_flops_info[name]['calls'] += 1
24+
25+
1926
@contextmanager
2027
def profile_ctx(name):
2128
""" Context manager for profiling
@@ -54,6 +61,21 @@ def get_profile_info():
5461
return _profile_info
5562

5663

64+
def get_flops_info():
65+
global _flops_info
66+
return _flops_info
67+
68+
69+
def reset_profile_info():
70+
global _profile_info
71+
_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})
72+
73+
74+
def reset_flops_info():
75+
global _flops_info
76+
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})
77+
78+
5779
def print_profile():
5880
global _profile_info
5981
profile_data = sorted(_profile_info.items(), key=lambda x: x[1]['time'],
@@ -73,6 +95,25 @@ def print_profile():
7395
print("Total profiled time: %g secs" % tot_time)
7496

7597

98+
def print_flops_info():
99+
global _flops_info
100+
flops_data = sorted(_flops_info.items(), key=lambda x: x[1]['flops'],
101+
reverse=True)
102+
if len(_flops_info) == 0:
103+
print("No flops information available")
104+
return
105+
print("FLOPS info:")
106+
print("{:<40} {:<10} {:<10}".format('Function', 'N calls', 'FLOPS'))
107+
tot_flops = 0
108+
for kernel, data in flops_data:
109+
print("{:<40} {:<10} {:<10}".format(
110+
kernel,
111+
data['calls'],
112+
data['flops']))
113+
tot_flops += data['flops']
114+
print("Total FLOPS: %i" % tot_flops)
115+
116+
76117
def profile_kernel(kernel, name, backend=None):
77118
"""For profiling raw PyCUDA/PyOpenCL kernels or cython functions
78119
"""

0 commit comments

Comments
 (0)