Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 9deb029

Browse files
Adding debug decorator to print Numba compile stats (#925)
* Adding debug decorator to print Numba compile stats. This adds a decorator that prints Numba compilation stats for a decorated function and all nested functions (literally all compiled overloads). Filtering can be done using decorator arguments. NOTE: this may show slightly lower times than the compile time calculated using the common approach of first_run_exec_time - second_run_exec_time. * Adding tests for compile time log format * Remove check for CPUDispatcher to work with numba DPPL
1 parent a39d73d commit 9deb029

4 files changed

Lines changed: 202 additions & 0 deletions

File tree

sdc/decorators.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import numba
3232
import sdc
3333

34+
from functools import wraps
35+
from sdc.utilities.utils import print_compile_times
36+
3437

3538
def jit(signature_or_function=None, **options):
3639

@@ -44,3 +47,29 @@ def jit(signature_or_function=None, **options):
4447
Use Numba compiler pipeline
4548
'''
4649
return numba.jit(signature_or_function, **options)
50+
51+
52+
def debug_compile_time(level=1, func_names=None):
    """ Builds a decorator that prints Numba compile stats after every call.

    The returned decorator wraps a Numba dispatcher object: each call is
    forwarded to the dispatcher and, once it returns, the compile stats
    collected for it are printed between two banner lines.

    Usage:
        @debug_compile_time()
        @numba.njit
        <decorated function>

    Args:
        level: if zero, only a short summary is printed
        func_names: restricts the output to functions whose names include
            at least one of the listed strings
    """

    banner_side = '*' * 40

    def decorate(dispatcher):

        @wraps(dispatcher)
        def call_and_report(*args, **kwargs):
            result = dispatcher(*args, **kwargs)

            # stats are reported only after the call has returned
            print(banner_side, 'COMPILE STATS', banner_side)
            print_compile_times(dispatcher, level=level, func_names=func_names)
            print('*' * 95)

            return result

        return call_and_report

    return decorate

sdc/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
from sdc.tests.test_sdc_numpy import *
5050
from sdc.tests.test_prange_utils import *
51+
from sdc.tests.test_compile_time import *
5152

5253
# performance tests
5354
import sdc.tests.tests_perf

sdc/tests/test_compile_time.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2019-2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import numba
28+
import numpy as np
29+
import pandas as pd
30+
import re
31+
import unittest
32+
33+
from contextlib import redirect_stdout
34+
from io import StringIO
35+
from sdc.tests.test_base import TestCase
36+
from sdc.decorators import debug_compile_time
37+
38+
39+
# Building blocks (regexp fragments) describing single lines of the log
# printed by @debug_compile_time; tests combine them into full-log patterns.
line_heading = r'\*+\s+COMPILE STATS\s+\*+\n'   # opening banner
line_function = r'Function: [^\s]+\n'           # name of a compiled function
line_args = r'\s+Args:.*\n'                     # argument types (detailed log only)
line_pipeline = r'\s+Pipeline: \w+\n'           # name of the compiler pipeline
line_passes = r'(\s+\w+\s+[\d.]+\n)+'           # per-pass timings (detailed log only)
line_time = r'\s+Time: [\d.]+\n'                # total time of the pipeline
line_ending = r'\*+\n'                          # closing banner
47+
48+
49+
class TestCompileTime(TestCase):
    """ Checks the format and filtering of the log printed by @debug_compile_time """

    @staticmethod
    def _gen_usecase_data():
        """ Returns two float/int Series of the same length used as test inputs """
        n = 11
        S1 = pd.Series(np.ones(n))
        S2 = pd.Series(2 ** np.arange(n))
        return S1, S2

    def _get_compile_log(self, **decorator_kwargs):
        """ Compiles and calls a Series addition decorated with
        @debug_compile_time(**decorator_kwargs) and returns everything
        that was printed to stdout during the call """

        @debug_compile_time(**decorator_kwargs)
        @self.jit
        def test_impl(S1, S2):
            return S1 + S2

        buffer = StringIO()
        with redirect_stdout(buffer):
            S1, S2 = self._gen_usecase_data()
            test_impl(S1, S2)

        return buffer.getvalue()

    def test_log_format_summary(self):
        """ Verifies shortened log format when only summary info is printed """
        log = self._get_compile_log(level=0)

        entry_format = fr'{line_function}{line_pipeline}{line_time}\n'
        log_format = fr'^{line_heading}({entry_format})+{line_ending}$'
        self.assertRegex(log, log_format)

    def test_log_format_detailed(self):
        """ Verifies detailed log format with passes and args information """
        log = self._get_compile_log()

        entry_format = fr'{line_function}{line_args}{line_pipeline}{line_passes}{line_time}\n'
        log_format = fr'{line_heading}({entry_format})+{line_ending}'
        self.assertRegex(log, log_format)

    def test_func_names_filter(self):
        """ Verifies filtering log entries via func_names parameter """
        searched_name = 'add'
        log = self._get_compile_log(func_names=[searched_name])

        # a dedicated pattern with a capture group is used here, so that the
        # module-level line_function pattern is not shadowed
        reported_names = re.findall(r'Function: ([^\s]+)\n', log)

        # fail with a clear message (instead of a StopIteration error)
        # if the log has no entries at all
        self.assertTrue(reported_names, 'no function entries found in compile stats log')

        # skip entry for top-level func, it's not affected by the filter
        for name in reported_names[1:]:
            self.assertIn(searched_name, name)
111+
112+
113+
if __name__ == "__main__":
    # allows running this test module directly as a script
    unittest.main()

sdc/utilities/utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
from numba.extending import overload, overload_method, overload_attribute
5252
from numba.extending import register_jitable, register_model
5353
from numba.core.datamodel.registry import register_default
54+
from functools import wraps
55+
from itertools import filterfalse, chain
5456

5557

5658
# int values for types to pass to C code
@@ -686,3 +688,59 @@ def sdc_overload_attribute(typ, name, jit_options={}, parallel=None, strict=True
686688
return overload_attribute(
687689
typ, name, jit_options=jit_options, strict=strict, inline=inline, prefer_literal=prefer_literal
688690
)
691+
692+
693+
def print_compile_times(disp, level, func_names=None):
    """ Prints compilation timings for a dispatcher and for nested overloads.

    Timings of every overload compiled for disp are printed first. Then
    function templates registered in disp.typingctx are scanned and timings
    of implementations stored in their _impl_cache are printed, in the
    order in which they are found (duplicates are reported once).

    Args:
        disp: dispatcher object, must provide 'overloads' (mapping of argument
            types to compile results) and 'typingctx' attributes
        level: if zero prints only a short summary per pipeline, otherwise
            argument types and run time of each compiler pass are printed too
        func_names: optional string or iterable of strings, if given only
            templates whose str() includes at least one of the strings are
            reported (overloads of disp itself are always reported)
    """

    pad = ' ' * 2

    def print_times(cres, args):
        print(f'Function: {cres.fndesc.unique_name}')
        if level:
            print(f'{pad * 1}Args: {args}')

        # NOTE(review): metadata is presumably missing for compile results that
        # were not produced by a compilation in this process (e.g. restored
        # from cache) - confirm; a debug printer should not crash on them
        metadata = getattr(cres, 'metadata', None) or {}
        times = metadata.get('pipeline_times')
        if times is None:
            print(f'{pad * 1}Pipeline times are not available\n')
            return

        for pipeline, pass_times in times.items():
            print(f'{pad * 1}Pipeline: {pipeline}')
            if level:
                for name, timings in pass_times.items():
                    print(f'{pad * 2}{name:50s}{timings.run:.13f}')

            pipeline_total = sum(t.init + t.run + t.finalize for t in pass_times.values())
            print(f'{pad * 1}Time: {pipeline_total}\n')

    # print times for compiled function indicated by disp
    for args, cres in disp.overloads.items():
        print_times(cres, args)

    # a single string is treated as one name (not as a sequence of characters),
    # any other iterable is materialized once as it's checked for every template
    if isinstance(func_names, str):
        searched_names = (func_names, )
    else:
        searched_names = tuple(func_names or ())

    # collect cached implementations of all known templates, dict is used
    # instead of set to drop duplicates but keep the output order stable
    known_funcs = disp.typingctx._functions
    cached_impls = {}
    for template in chain.from_iterable(known_funcs.values()):
        tmpl_cached_impls = getattr(template, '_impl_cache', None)
        if not tmpl_cached_impls:
            continue

        # filter only function names that are in the func_names list
        if searched_names:
            template_str = str(template)
            if not any(name in template_str for name in searched_names):
                continue

        for impl_cache in tmpl_cached_impls.values():
            cached_impls.setdefault(impl_cache, None)

    for impl_cache in cached_impls:
        # impl_cache is usually a tuple of format (dispatcher, args)
        # if not just skip these entries
        if not (isinstance(impl_cache, tuple)
                and len(impl_cache) == 2
                and isinstance(impl_cache[0], type(disp))):
            continue

        fndisp, args = impl_cache
        overloads = getattr(fndisp, 'overloads', None)
        if not overloads:
            continue

        # with a single overload the argument types stored in the template
        # cache are reported, if there are several of them each one is
        # reported with the signature it was compiled for
        is_single = len(overloads) == 1
        for sig, cres in overloads.items():
            print_times(cres, args if is_single else sig)

0 commit comments

Comments
 (0)