Skip to content

Commit 8fd1c86

Browse files
committed
Add flop counter to elementwise for opencl/cuda
1 parent 5c62fa0 commit 8fd1c86

File tree

9 files changed

+469
-47
lines changed

9 files changed

+469
-47
lines changed

compyle/config.py

+14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self):
1515
self._use_double = None
1616
self._omp_schedule = None
1717
self._profile = None
18+
self._count_flops = None
1819
self._use_local_memory = None
1920
self._wgs = None
2021
self._suppress_warnings = None
@@ -129,6 +130,19 @@ def profile(self, value):
129130
def _profile_default(self):
    """Return ``False``.

    NOTE(review): by analogy with ``_count_flops_default`` this is presumably
    the fallback used by the ``profile`` option when it has not been set; the
    getter is not visible here -- confirm.
    """
    return False
131132

133+
@property
def count_flops(self):
    """Whether flop counting is enabled.

    The value is resolved lazily: while the stored value is still ``None``
    the result of ``_count_flops_default()`` is stored and returned.
    """
    current = self._count_flops
    if current is None:
        current = self._count_flops_default()
        self._count_flops = current
    return current

@count_flops.setter
def count_flops(self, value):
    """Store ``value`` unchanged; storing ``None`` re-enables the default lookup."""
    self._count_flops = value

def _count_flops_default(self):
    """Return the default for ``count_flops``, which is ``False``."""
    return False
145+
132146
@property
133147
def use_local_memory(self):
134148
if self._use_local_memory is None:

compyle/jit.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
dtype_to_knowntype, annotate)
1515
from .extern import Extern
1616
from .utils import getsourcelines
17-
from .profile import profile
17+
from .profile import record_flops, profile
1818

1919
from . import array
2020
from . import parallel
@@ -287,11 +287,12 @@ def visit_Return(self, node):
287287
class ElementwiseJIT(parallel.ElementwiseBase):
288288
def __init__(self, func, backend=None):
289289
backend = array.get_backend(backend)
290-
self.tp = Transpiler(backend=backend)
290+
self._config = get_config()
291+
self.tp = Transpiler(backend=backend,
292+
count_flops=self._config.count_flops)
291293
self.backend = backend
292294
self.name = 'elwise_%s' % func.__name__
293295
self.func = func
294-
self._config = get_config()
295296
self.cython_gen = CythonGenerator()
296297
if backend == 'opencl':
297298
from .opencl import get_context, get_queue
@@ -331,6 +332,10 @@ def _massage_arg(self, x):
331332
def __call__(self, *args, **kw):
332333
c_func = self._generate_kernel(*args)
333334
c_args = [self._massage_arg(x) for x in args]
335+
if self._config.count_flops:
336+
flop_counter = array.zeros(args[0].length, np.int64,
337+
backend=self.backend)
338+
c_args.append(flop_counter.dev)
334339

335340
if self.backend == 'cython':
336341
size = len(c_args[0])
@@ -345,6 +350,9 @@ def __call__(self, *args, **kw):
345350
c_func(*c_args, **kw)
346351
event.record()
347352
event.synchronize()
353+
if self._config.count_flops:
354+
flops = array.sum(flop_counter)
355+
record_flops(self.name, flops)
348356

349357

350358
class ReductionJIT(parallel.ReductionBase):
@@ -517,7 +525,7 @@ def __call__(self, **kwargs):
517525
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
518526
if self._get_backend_key() in self.output_func.arg_keys:
519527
output_arg_keys = self.output_func.arg_keys[
520-
self._get_backend_key()]
528+
self._get_backend_key()]
521529
else:
522530
raise ValueError("No kernel arguments found for backend = %s, "
523531
"use_openmp = %s, use_double = %s" %

compyle/parallel.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import time
1515

1616
from .config import get_config
17-
from .profile import profile
17+
from .profile import record_flops, profile
1818
from .cython_generator import get_parallel_range, CythonGenerator
1919
from .transpiler import Transpiler, convert_to_float_if_needed
2020
from .types import dtype_to_ctype
@@ -405,11 +405,12 @@ def get_common_cache_key(obj):
405405
class ElementwiseBase(object):
406406
def __init__(self, func, backend=None):
407407
backend = array.get_backend(backend)
408-
self.tp = Transpiler(backend=backend)
408+
self._config = get_config()
409+
self.tp = Transpiler(backend=backend,
410+
count_flops=self._config.count_flops)
409411
self.backend = backend
410412
self.name = 'elwise_%s' % func.__name__
411413
self.func = func
412-
self._config = get_config()
413414
self.cython_gen = CythonGenerator()
414415
self.queue = None
415416
self.c_func = self._generate()
@@ -446,11 +447,17 @@ def _generate(self, declarations=None):
446447
ctx = get_context()
447448
self.queue = get_queue()
448449
name = self.func.__name__
450+
call_args = ', '.join(c_data[1])
451+
if self._config.count_flops:
452+
call_args += ', cpy_flop_counter'
449453
expr = '{func}({args})'.format(
450454
func=name,
451-
args=', '.join(c_data[1])
455+
args=call_args
452456
)
453457
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
458+
if self._config.count_flops:
459+
arguments += ', long* cpy_flop_counter'
460+
454461
preamble = convert_to_float_if_needed(self.tp.get_code())
455462
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
456463
double_support=True
@@ -472,11 +479,17 @@ def _generate(self, declarations=None):
472479
from pycuda.elementwise import ElementwiseKernel
473480
from pycuda._cluda import CLUDA_PREAMBLE
474481
name = self.func.__name__
482+
call_args = ', '.join(c_data[1])
483+
if self._config.count_flops:
484+
call_args += ', cpy_flop_counter'
475485
expr = '{func}({args})'.format(
476486
func=name,
477-
args=', '.join(c_data[1])
487+
args=call_args
478488
)
479489
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
490+
if self._config.count_flops:
491+
arguments += ', long* cpy_flop_counter'
492+
480493
preamble = convert_to_float_if_needed(self.tp.get_code())
481494
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
482495
double_support=True
@@ -504,6 +517,8 @@ def _add_address_space(arg):
504517
return arg
505518

506519
args = [_add_address_space(arg) for arg in c_data[0]]
520+
if self._config.count_flops:
521+
args.append('GLOBAL_MEM long* cpy_flop_counter')
507522
code[:header_idx] = wrap(
508523
'WITHIN_KERNEL void {func}({args})'.format(
509524
func=self.func.__name__,
@@ -512,6 +527,14 @@ def _add_address_space(arg):
512527
width=78, subsequent_indent=' ' * 4, break_long_words=False
513528
)
514529
self.tp.blocks[-1].code = '\n'.join(code)
530+
if self._config.count_flops:
531+
for idx, block in enumerate(self.tp.blocks[:-1]):
532+
self.tp.blocks[idx].code = block.code.replace(
533+
'${offset}', '0'
534+
)
535+
self.tp.blocks[-1].code = self.tp.blocks[-1].code.replace(
536+
'${offset}', 'i'
537+
)
515538

516539
def _massage_arg(self, x):
517540
if isinstance(x, array.Array):
@@ -524,6 +547,10 @@ def _massage_arg(self, x):
524547
@profile
525548
def __call__(self, *args, **kw):
526549
c_args = [self._massage_arg(x) for x in args]
550+
if self._config.count_flops:
551+
flop_counter = array.zeros(args[0].length, np.int64,
552+
backend=self.backend)
553+
c_args.append(flop_counter.dev)
527554
if self.backend == 'cython':
528555
size = len(c_args[0])
529556
c_args.insert(0, size)
@@ -537,6 +564,9 @@ def __call__(self, *args, **kw):
537564
self.c_func(*c_args, **kw)
538565
event.record()
539566
event.synchronize()
567+
if self._config.count_flops:
568+
flops = array.sum(flop_counter)
569+
record_flops(self.name, flops)
540570

541571

542572
class Elementwise(object):
@@ -1113,11 +1143,12 @@ def _massage_arg(self, x):
11131143
def __call__(self, **kwargs):
11141144
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
11151145
if self._get_backend_key() in self.output_func.arg_keys:
1116-
output_arg_keys = self.output_func.arg_keys[self._get_backend_key()]
1146+
output_arg_keys = self.output_func.arg_keys[self._get_backend_key(
1147+
)]
11171148
else:
11181149
raise ValueError("No kernel arguments found for backend = %s, "
1119-
"use_openmp = %s, use_double = %s" %
1120-
self._get_backend_key())
1150+
"use_openmp = %s, use_double = %s" %
1151+
self._get_backend_key())
11211152

11221153
if self.backend == 'cython':
11231154
size = len(c_args_dict[output_arg_keys[1]])

compyle/profile.py

+41
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})
11+
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})
1112

1213

1314
def _record_profile(name, time):
@@ -16,6 +17,12 @@ def _record_profile(name, time):
1617
_profile_info[name]['calls'] += 1
1718

1819

20+
def record_flops(name, flops):
    """Add ``flops`` to the running total for ``name`` and count one call.

    Parameters
    ----------
    name:
        Key under which the counts are stored in ``_flops_info``.
    flops:
        Amount added to the ``'flops'`` total of that entry.
    """
    # Only the contents are mutated, so no ``global`` declaration is needed.
    # Indexing creates a zeroed entry the first time ``name`` is seen.
    entry = _flops_info[name]
    entry['flops'] += flops
    entry['calls'] += 1
24+
25+
1926
@contextmanager
2027
def profile_ctx(name):
2128
""" Context manager for profiling
@@ -54,6 +61,21 @@ def get_profile_info():
5461
return _profile_info
5562

5663

64+
def get_flops_info():
    """Return the module-level flop registry itself (the live object, not a copy)."""
    return _flops_info
67+
68+
69+
def reset_profile_info():
    """Replace the profiling registry with a fresh, empty one.

    The module global is rebound to a new ``defaultdict``. A registry object
    obtained before this call is left untouched and no longer receives
    updates.
    """
    global _profile_info

    def _blank_entry():
        # Per-name record created on first access.
        return {'calls': 0, 'time': 0}

    _profile_info = defaultdict(_blank_entry)
72+
73+
74+
def reset_flops_info():
    """Replace the flop registry with a fresh, empty one.

    The module global is rebound to a new ``defaultdict``. A registry object
    obtained before this call is left untouched and no longer receives
    updates.
    """
    global _flops_info

    def _blank_entry():
        # Per-name record created on first access.
        return {'calls': 0, 'flops': 0}

    _flops_info = defaultdict(_blank_entry)
77+
78+
5779
def print_profile():
5880
global _profile_info
5981
profile_data = sorted(_profile_info.items(), key=lambda x: x[1]['time'],
@@ -73,6 +95,25 @@ def print_profile():
7395
print("Total profiled time: %g secs" % tot_time)
7496

7597

98+
def print_flops_info():
    """Print a table of the recorded flop counts to stdout.

    Prints a single notice and returns when nothing has been recorded.
    Otherwise prints one row per entry (name, number of calls, accumulated
    flops), ordered by accumulated flops from largest to smallest, followed
    by the grand total.
    """
    if not _flops_info:
        print("No flops information available")
        return

    ranked = sorted(_flops_info.items(),
                    key=lambda item: item[1]['flops'],
                    reverse=True)
    row = "{:<40} {:<10} {:<10}"
    print("FLOPS info:")
    print(row.format('Function', 'N calls', 'FLOPS'))
    total = 0
    for kernel, stats in ranked:
        print(row.format(kernel, stats['calls'], stats['flops']))
        total += stats['flops']
    print("Total FLOPS: %i" % total)
115+
116+
76117
def profile_kernel(kernel, name, backend=None):
77118
"""For profiling raw PyCUDA/PyOpenCL kernels or cython functions
78119
"""

0 commit comments

Comments
 (0)