14
14
import time
15
15
16
16
from .config import get_config
17
- from .profile import profile
17
+ from .profile import record_flops , profile
18
18
from .cython_generator import get_parallel_range , CythonGenerator
19
19
from .transpiler import Transpiler , convert_to_float_if_needed
20
20
from .types import dtype_to_ctype
@@ -405,11 +405,12 @@ def get_common_cache_key(obj):
405
405
class ElementwiseBase (object ):
406
406
def __init__ (self , func , backend = None ):
407
407
backend = array .get_backend (backend )
408
- self .tp = Transpiler (backend = backend )
408
+ self ._config = get_config ()
409
+ self .tp = Transpiler (backend = backend ,
410
+ count_flops = self ._config .count_flops )
409
411
self .backend = backend
410
412
self .name = 'elwise_%s' % func .__name__
411
413
self .func = func
412
- self ._config = get_config ()
413
414
self .cython_gen = CythonGenerator ()
414
415
self .queue = None
415
416
self .c_func = self ._generate ()
@@ -446,11 +447,17 @@ def _generate(self, declarations=None):
446
447
ctx = get_context ()
447
448
self .queue = get_queue ()
448
449
name = self .func .__name__
450
+ call_args = ', ' .join (c_data [1 ])
451
+ if self ._config .count_flops :
452
+ call_args += ', cpy_flop_counter'
449
453
expr = '{func}({args})' .format (
450
454
func = name ,
451
- args = ', ' . join ( c_data [ 1 ])
455
+ args = call_args
452
456
)
453
457
arguments = convert_to_float_if_needed (', ' .join (c_data [0 ][1 :]))
458
+ if self ._config .count_flops :
459
+ arguments += ', long* cpy_flop_counter'
460
+
454
461
preamble = convert_to_float_if_needed (self .tp .get_code ())
455
462
cluda_preamble = Template (text = CLUDA_PREAMBLE ).render (
456
463
double_support = True
@@ -472,11 +479,17 @@ def _generate(self, declarations=None):
472
479
from pycuda .elementwise import ElementwiseKernel
473
480
from pycuda ._cluda import CLUDA_PREAMBLE
474
481
name = self .func .__name__
482
+ call_args = ', ' .join (c_data [1 ])
483
+ if self ._config .count_flops :
484
+ call_args += ', cpy_flop_counter'
475
485
expr = '{func}({args})' .format (
476
486
func = name ,
477
- args = ', ' . join ( c_data [ 1 ])
487
+ args = call_args
478
488
)
479
489
arguments = convert_to_float_if_needed (', ' .join (c_data [0 ][1 :]))
490
+ if self ._config .count_flops :
491
+ arguments += ', long* cpy_flop_counter'
492
+
480
493
preamble = convert_to_float_if_needed (self .tp .get_code ())
481
494
cluda_preamble = Template (text = CLUDA_PREAMBLE ).render (
482
495
double_support = True
@@ -504,6 +517,8 @@ def _add_address_space(arg):
504
517
return arg
505
518
506
519
args = [_add_address_space (arg ) for arg in c_data [0 ]]
520
+ if self ._config .count_flops :
521
+ args .append ('GLOBAL_MEM long* cpy_flop_counter' )
507
522
code [:header_idx ] = wrap (
508
523
'WITHIN_KERNEL void {func}({args})' .format (
509
524
func = self .func .__name__ ,
@@ -512,6 +527,14 @@ def _add_address_space(arg):
512
527
width = 78 , subsequent_indent = ' ' * 4 , break_long_words = False
513
528
)
514
529
self .tp .blocks [- 1 ].code = '\n ' .join (code )
530
+ if self ._config .count_flops :
531
+ for idx , block in enumerate (self .tp .blocks [:- 1 ]):
532
+ self .tp .blocks [idx ].code = block .code .replace (
533
+ '${offset}' , '0'
534
+ )
535
+ self .tp .blocks [- 1 ].code = self .tp .blocks [- 1 ].code .replace (
536
+ '${offset}' , 'i'
537
+ )
515
538
516
539
def _massage_arg (self , x ):
517
540
if isinstance (x , array .Array ):
@@ -524,6 +547,10 @@ def _massage_arg(self, x):
524
547
@profile
525
548
def __call__ (self , * args , ** kw ):
526
549
c_args = [self ._massage_arg (x ) for x in args ]
550
+ if self ._config .count_flops :
551
+ flop_counter = array .zeros (args [0 ].length , np .int64 ,
552
+ backend = self .backend )
553
+ c_args .append (flop_counter .dev )
527
554
if self .backend == 'cython' :
528
555
size = len (c_args [0 ])
529
556
c_args .insert (0 , size )
@@ -537,6 +564,9 @@ def __call__(self, *args, **kw):
537
564
self .c_func (* c_args , ** kw )
538
565
event .record ()
539
566
event .synchronize ()
567
+ if self ._config .count_flops :
568
+ flops = array .sum (flop_counter )
569
+ record_flops (self .name , flops )
540
570
541
571
542
572
class Elementwise (object ):
@@ -1113,11 +1143,12 @@ def _massage_arg(self, x):
1113
1143
def __call__ (self , ** kwargs ):
1114
1144
c_args_dict = {k : self ._massage_arg (x ) for k , x in kwargs .items ()}
1115
1145
if self ._get_backend_key () in self .output_func .arg_keys :
1116
- output_arg_keys = self .output_func .arg_keys [self ._get_backend_key ()]
1146
+ output_arg_keys = self .output_func .arg_keys [self ._get_backend_key (
1147
+ )]
1117
1148
else :
1118
1149
raise ValueError ("No kernel arguments found for backend = %s, "
1119
- "use_openmp = %s, use_double = %s" %
1120
- self ._get_backend_key ())
1150
+ "use_openmp = %s, use_double = %s" %
1151
+ self ._get_backend_key ())
1121
1152
1122
1153
if self .backend == 'cython' :
1123
1154
size = len (c_args_dict [output_arg_keys [1 ]])
0 commit comments