
MLPerf Tiny developments #503


Open
Wants to merge 117 commits into base: main
117 commits
89c3d85
Added QConv2DBatchnorm support.
nghielme Mar 8, 2021
0ae678d
Added QConv2DBatchnorm support.
nghielme Mar 8, 2021
0846766
Added support for QConv2DBatchnorm.
nghielme Mar 8, 2021
dade7de
Fixed problems emerged from PR discussion.
nghielme Mar 9, 2021
ae5afbf
Fixed problems emerged from PR discussion.
nghielme Mar 9, 2021
02f92e2
Fixed problems emerged from PR discussion.
nghielme Mar 9, 2021
b829039
Fixed problems emerged from PR discussion.
nghielme Mar 9, 2021
20d9a62
Some improvements + testing if weights are provided already folded.
nghielme Mar 10, 2021
6e6afab
Added support for `QConv2DBatchnorm` with `use_bias=False`.
nghielme Mar 12, 2021
7649a80
Removing the check on batch size. "" -> ''.
nghielme Mar 19, 2021
a1d86b8
Last updates before merging PR.
nghielme Mar 23, 2021
5e988c2
Merge branch 'fastmachinelearning:master' into master
nghielme Jun 24, 2021
aa246b8
HLS implementation of Concatenate for io_stream
vloncar Aug 6, 2021
bee6c05
Fix concatenate3d_2 for io_parallel
vloncar Aug 6, 2021
a0b4a00
Merge branch 'concat_io_stream' of https://github.com/vloncar/hls4ml …
nghielme Aug 9, 2021
33d1e91
Merge remote-tracking branch 'origin/pynq-pr' into vivado-accel-concat
nghielme Aug 9, 2021
810e2d5
Support overriding stream depth in config
vloncar Aug 10, 2021
59e3766
Added a first version of build_prj.tcl that is able to generate the .…
nghielme Sep 2, 2021
b72260c
Merge remote-tracking branch 'vlad/pynq_enet' into fifo_depth_opt
nghielme Sep 2, 2021
306e381
Fixed the assignment of the unassigned `inp` variable
nghielme Sep 2, 2021
da2c46a
Bump version
thesps Sep 24, 2021
707dba5
First version of the optimization flow automated. Tested with ENet mo…
nghielme Sep 27, 2021
ee4733b
Fixing minor issues
nghielme Sep 27, 2021
93cca6c
Fixing minor issues
nghielme Sep 27, 2021
14fa32d
Fixing minor issues
nghielme Sep 28, 2021
fe8519f
fix conv1d io_parallel resource (#403)
jmitrevs Sep 30, 2021
038c4b6
Fixing minor issues
nghielme Oct 1, 2021
0f915bf
Handling fifos not implemented in brams
nghielme Oct 1, 2021
c76ea08
Modified `name` to `cppname` for `InplaceVariable`
nghielme Oct 1, 2021
4614bfe
Added `exit` in build_prj.tcl file
nghielme Oct 1, 2021
c136ff8
`name` -> `cppname`
nghielme Oct 1, 2021
3d48d0d
Small fix
nghielme Oct 1, 2021
fd8a618
Parallelise pytests in CI
thesps Oct 5, 2021
f5acaca
Fix GlobalPooling1D Layers (#399)
jmduarte Oct 7, 2021
231fdcb
Set appropriate data type for quantized_relu activations
thesps Sep 29, 2021
b361265
Display unsigned types properly in profiling
thesps Sep 29, 2021
c0143a1
Some fixes
nghielme Oct 18, 2021
0ea8d88
Some fixes
nghielme Oct 18, 2021
7ed6fcc
Adding one dependency
nghielme Oct 19, 2021
92bf513
Fixing bugs in the `optimize_fifos_depth` function
nghielme Oct 20, 2021
888713b
Fixing bugs in `optimize_fifos_depth` function
nghielme Oct 21, 2021
ffc9845
Testing without `log_wave -r /`
nghielme Nov 8, 2021
0cfb79c
- Testing with all build parameters true by default
nghielme Nov 8, 2021
be7dd00
- Testing with all build parameters true by default
nghielme Nov 8, 2021
f10e0c3
fix batched multiple inputs
jmduarte Oct 11, 2021
d1d5eba
Fixed 'qkeras_mnist_dense' example build problem #423
siorpaes Oct 26, 2021
997c07b
Update for pyyaml 6.0 (#435)
thesps Nov 4, 2021
d0ff8ba
`axi_stream_driver` update (#420)
nghielme Nov 5, 2021
ba3902f
Reshape fixes: don't repack stream for flatten; remove final reshape …
jmduarte Nov 9, 2021
86e4397
Reorder loops in im2col_2d_cl given resource strategy issue. Reenable…
thesps Nov 9, 2021
36cd38c
Support applying Softmax over multidimensional tensors (#384)
vloncar Nov 11, 2021
9f5c249
Disable some unsupported layers
thesps Nov 9, 2021
ab076fc
Start integrating Vivado Accelerator AXI-m backend
Nov 26, 2021
fa4bd2b
build_prj.tcl fixed for vivadoaccelerator - there will be problems fo…
nghielme Nov 26, 2021
a1384cd
update example-models
nghielme Nov 26, 2021
8f64675
Setup AXI-m backend for PYNQ-Z1 and PYNQ-Z2
Nov 28, 2021
1a64d4b
Add C drivers to PYNQ-Z1
Nov 28, 2021
c1f4670
Move data.h generator into hls4ml
Nov 29, 2021
ca26680
Merge remote-tracking branch 'hls4ml/master' into fifo_depth_opt
nghielme Dec 1, 2021
aefe0c9
Copy directory even when the directory exists
Dec 2, 2021
77ff677
String compare with '==' rather than 'is'
Dec 11, 2021
a9aeb33
Add validation besides verification to the software application (Pynq…
Jan 3, 2022
d146994
Set stack/heap size in the linker script
Jan 5, 2022
08ee073
ATTENTION: force 1 the size of the scratchpads in the AXI-M wrapper (…
Jan 5, 2022
cd27c4c
Set software for ultra96-v2
Jan 7, 2022
feef1f9
init commit of QDenseBN
julesmuhizi Jan 11, 2022
97b14ea
Added changes needed to carry out the relu_merge optimization
anmeza Jan 16, 2022
c499f47
clean up code and currently testing if checking the merged relu flag …
oliviaweng Jan 18, 2022
04a382e
remove comments and extra brace
oliviaweng Jan 19, 2022
7164673
revert back to verbose template because it achieves maximum resource …
oliviaweng Jan 19, 2022
b0aa186
Merge pull request #1 from anmeza/relu_fuse
jmduarte Jan 24, 2022
c513c91
Merge branch 'fifo_depth_opt' into gdg/axi-m
jmduarte Jan 25, 2022
1ee5708
add back cppname
jmduarte Jan 25, 2022
4c7321a
added weight transponse and qkeras quantization support
julesmuhizi Jan 25, 2022
afd8249
patch-2 - added weight transponse and qkeras quantization support
julesmuhizi Jan 25, 2022
023321d
add c drivers for pynq-z2
jmduarte Jan 25, 2022
2e069ae
remove uncessary _add_supported_quantized_objects call
julesmuhizi Jan 26, 2022
15eb287
Add scripts and application for ARTY-A7 with VivadoAccelerator backend
Feb 9, 2022
cb7fc44
Cleanup timers code for ARTY-A7
Feb 9, 2022
d1dab8d
Merge pull request #5 from GiuseppeDiGuglielmo/gdg/axi-m
jmduarte Feb 9, 2022
fd4df8e
Merge pull request #2 from hls4ml-finn-mlperftiny/gdg/axi-m
jmduarte Feb 14, 2022
361628d
Integration of features for EEMBC power analysis.
Feb 16, 2022
17fc6f8
Merge pull request #7 from hls4ml-finn-mlperftiny/fifo_depth_opt
julesmuhizi Feb 18, 2022
a497085
try out jules's fix to relu_merge optimizer pass to work with dense-r…
oliviaweng Feb 21, 2022
c9daa06
have dense layer check merged_relu flag to see whether or not to set …
oliviaweng Feb 21, 2022
949bafd
fix how merged_relu flag is set for Dense and Conv2D in hls_layers.py
oliviaweng Feb 21, 2022
8e3df9e
debug print
oliviaweng Feb 21, 2022
f298517
add merged_relu attribute to Layer class
oliviaweng Feb 21, 2022
f6bf06a
remove debug prints and comments and DenseBatchnorm from relu_merge s…
oliviaweng Feb 21, 2022
9c08a24
Delete relu_merge_old.py
jmduarte Feb 22, 2022
77a2db1
Update core.py
jmduarte Feb 22, 2022
8b494a2
Update core.py
jmduarte Feb 22, 2022
a18a204
apply patches
jmduarte Feb 22, 2022
3590aa4
remove unnecessary patches
jmduarte Feb 22, 2022
4173797
Merge pull request #8 from anmeza/relu_fuse
jmduarte Feb 22, 2022
b2f499c
Merge pull request #10 from hls4ml-finn-mlperftiny/div_by_256_patch
jmduarte Feb 22, 2022
ad311fc
Merge pull request #4 from hls4ml-finn-mlperftiny/qdense_batchnorm
jmduarte Feb 22, 2022
532aff8
small fixes
jmduarte Feb 22, 2022
541acac
Merge pull request #6 from hls4ml-finn-mlperftiny/fifo_depth_opt_dev
jmduarte Feb 22, 2022
e48ae63
just pass hls_model
jmduarte Feb 22, 2022
c18064c
no need to pass keras model
jmduarte Feb 22, 2022
d8b435d
correct scope for arty
jmduarte Feb 24, 2022
8c51558
weird hack for now
jmduarte Feb 28, 2022
5f3b468
remove prints
jmduarte Mar 1, 2022
6c0d205
Merge pull request #11 from hls4ml-finn-mlperftiny/arty_fifo
jmduarte Mar 1, 2022
e808a05
Support overriding stream depth in config
vloncar Aug 10, 2021
e58a986
bug fixes
jmduarte Mar 2, 2022
6bbf88c
get correct type for data header
jmduarte Mar 2, 2022
7f209e2
smaller MCU
jmduarte Mar 2, 2022
c424ac1
hardcode name which works for current models :shrug:
jmduarte Mar 2, 2022
6e6730d
generalize with regex
jmduarte Mar 2, 2022
fb3def1
make eembc_power configurable
jmduarte Mar 3, 2022
ed3ffc7
Use SystemVerilog define to control EEMBC setup
Mar 4, 2022
7452d5a
Update axi_master_design.tcl
GiuseppeDiGuglielmo Mar 4, 2022
f4613f0
Update design_1_wrapper.v
GiuseppeDiGuglielmo Mar 4, 2022
9a411f0
Fix but w/ clock connection that stopped the QSPI boot
Mar 9, 2022
5c0ad77
Fix bug w/ clock connection that stopped the QSPI boot
Mar 9, 2022
2 changes: 1 addition & 1 deletion example-models
2 changes: 1 addition & 1 deletion hls4ml/__init__.py
@@ -1,6 +1,6 @@
from __future__ import absolute_import

__version__ = '0.5.1'
__version__ = '0.6.0'

from hls4ml import converters
from hls4ml import report
2 changes: 1 addition & 1 deletion hls4ml/converters/keras/core.py
@@ -104,7 +104,7 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader,

@keras_handler('BatchNormalization')
def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader, config):
assert('BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'])
assert('BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'] or 'QDenseBatchnorm' in keras_layer['class_name'])

layer = parse_default_keras_layer(keras_layer, input_names)

9 changes: 9 additions & 0 deletions hls4ml/converters/keras/qkeras_layers.py
@@ -110,3 +110,12 @@ def parse_qconv2dbatchnorm_layer(keras_layer, input_names, input_shapes, data_re
temp_shape = intermediate_shape
batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader, config)
return {**conv_layer, **batch_layer}, out_shape

@keras_handler('QDenseBatchnorm')
def parse_qdensebatchnorm_layer(keras_layer, input_names, input_shapes, data_reader, config):
intermediate_shape = list()
dense_layer, shape_qdense = parse_qdense_layer(keras_layer, input_names, input_shapes, data_reader, config)
intermediate_shape.append(shape_qdense)
temp_shape = intermediate_shape
batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader, config)
return {**dense_layer, **batch_layer}, out_shape
74 changes: 72 additions & 2 deletions hls4ml/model/hls_layers.py
@@ -199,6 +199,7 @@ def __init__(self, shape, dim_names, proxy, **kwargs):
self.shape = shape
self.dim_names = dim_names
self.type = proxy.type
self.cppname = proxy.name
self.name = proxy.name
self.size = proxy.size

@@ -365,6 +366,7 @@ def __init__(self, model, name, attributes, inputs, outputs=None):
self.set_attr('accum_t', accum_t.precision)
self.reuse_factor = self.model.config.get_reuse_factor(self)
self.target_cycles = self.model.config.get_target_cycles(self)
self.merged_relu = False

layer_config = self.model.config.get_layer_config(self)
for config_key, config_value in layer_config.items():
@@ -410,6 +412,10 @@ def get_output_variable(self, output_name=None):
else:
return next(iter(self.variables.values()))

def set_output_variable(self, output_name, output_value):
self.variables[output_name] = output_value


def get_weights(self, var_name=None):
if var_name:
return self.weights[var_name]
@@ -450,6 +456,8 @@ def make_array_variable(self, shape, dim_names, var_name='layer{index}_out', typ

def make_stream_variable(self, shape, dim_names, var_name='layer{index}_out', type_name='layer{index}_t', precision=None, depth=0):
pack_factor = self.model.config.get_layer_config_value(self, 'PackFactor', default=1)
if depth == 0:
depth = self.model.config.get_layer_config_value(self, 'StreamDepth', default=0)

return StreamVariable(shape, dim_names, var_name=var_name, type_name=type_name, precision=precision, n_pack=pack_factor, depth=depth, index=self.index)

@@ -541,6 +549,12 @@ def _default_config_params(self):
def get_layer_precision(self):
return self.precision

def get_merged_relu(self):
return self.merged_relu

def set_merged_relu(self, merged_relu):
self.merged_relu = merged_relu # Bool flag to set merged_relu

# myproject.cpp/h
def function_cpp(self):
raise NotImplementedError
@@ -589,7 +603,6 @@ def initialize(self):
out_name = self.outputs[0]
proxy = self.get_input_variable()
out = InplaceVariable(shape, dims, proxy, index=self.get_input_node().index)

self.variables[out_name] = out
self.model.register_output_variable(out_name, out)

@@ -646,9 +659,61 @@ def config_cpp(self):
params['nonzeros'] = self.get_weights('weight').nonzeros
params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision)
params['strategy'] = self.get_attr('strategy')

params['merged_relu'] = "true" if self.get_merged_relu() else "false"
params['out_t'] = self.get_output_variable().type.name
return self._config_template.format(**params)

class DenseBatchnorm(Dense):
def _get_folded_weights(self):
"""
Function to get the batchnorm folded weights.
This function converts the weights by folding batchnorm parameters into
the weight of QDense. The high-level equation:
W_fold = gamma * W / sqrt(variance + epsilon)
bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta
"""
kernel = self.model.get_weights_data(self.name, 'kernel')
bias = self.model.get_weights_data(self.name, 'bias')
if bias is None:
bias = 0

# get batchnorm weights and moving stats
gamma = self.model.get_weights_data(self.name, 'gamma')
beta = self.model.get_weights_data(self.name, 'beta')
moving_mean = self.model.get_weights_data(self.name, 'moving_mean')
moving_variance = self.model.get_weights_data(self.name, 'moving_variance')
# get the inversion factor so that we replace division by multiplication
inv = np.reciprocal(np.sqrt(moving_variance + self.get_attr('epsilon')))
if gamma is not None:
inv *= gamma

# wrap conv kernel and bias with bn parameters
folded_kernel = inv * kernel
folded_bias = inv * (bias - moving_mean) + beta

return [folded_kernel, folded_bias]

def initialize(self):
super(DenseBatchnorm, self).initialize()
folded_weights, folded_bias = self._get_folded_weights()
if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
self.weights['weight'].data_unquantized = np.transpose(folded_weights)
self.weights['weight'].data = self.get_attr('weight_quantizer')(self.weights['weight'].data_unquantized)

else:
self.weights['weight'].data_unquantized = folded_weights
self.weights['weight'].data = self.get_attr('weight_quantizer')(folded_weights)
self.weights['bias'].data_unquantized = folded_bias
bias_q = self.get_attr('bias_quantizer')
if bias_q is not None:
self.weights['bias'].data = bias_q(folded_bias)

def function_cpp(self):
return super(DenseBatchnorm, self).function_cpp()

def config_cpp(self):
return super(DenseBatchnorm, self).config_cpp()

class Conv1D(Layer):
def initialize(self):
if self.get_attr('data_format') == 'channels_last':
@@ -854,7 +919,9 @@ def initialize(self):
else:
shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']]
dims = ['N_FILT_{}'.format(self.index), 'OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index)]
self.attributes['intermediate_index'] = self.index
self.add_output_variable(shape, dims)
self.intermediate_op = self.get_output_variable()
self.add_weights(quantizer=self.get_attr('weight_quantizer'))
self.add_bias(quantizer=self.get_attr('bias_quantizer'))
if len(self.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
@@ -921,6 +988,8 @@ def config_cpp(self):
mult_params['n_in'] = self.get_attr('n_chan') * self.get_attr('filt_height') * self.get_attr('filt_width')
mult_params['n_out'] = self.get_attr('n_filt')
mult_params['product_type'] = self.model.config.backend.product_type(self.get_input_variable().type.precision, self.get_weights('weight').type.precision)
mult_params['merged_relu'] = "true" if self.get_merged_relu() else "false"
mult_params['out_t'] = self.intermediate_op.type.name
mult_config = self._config_template[1].format(**mult_params)

return mult_config + '\n' + conv_config
@@ -1865,6 +1934,7 @@ def _get_transforms_config(self, params):
'BinaryDense' : Dense,
'TernaryDense' : Dense,
'QDense' : Dense,
'QDenseBatchnorm' : DenseBatchnorm,
'Conv1D' : Conv1D,
'QConv1D' : Conv1D,
'Conv2D' : Conv2D,
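For context on the folding used by DenseBatchnorm._get_folded_weights above, here is a minimal numpy check (illustrative values only, not part of the diff) of the docstring equations W_fold = gamma * W / sqrt(variance + epsilon) and bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))          # toy batch: 4 inputs
W = rng.normal(size=(4, 3))          # dense kernel: 4 -> 3 units
b = rng.normal(size=(3,))

gamma = rng.normal(size=(3,))        # batchnorm parameters as stored by Keras
beta = rng.normal(size=(3,))
moving_mean = rng.normal(size=(3,))
moving_variance = rng.uniform(0.5, 2.0, size=(3,))
epsilon = 1e-3

# Folding, mirroring _get_folded_weights
inv = np.reciprocal(np.sqrt(moving_variance + epsilon)) * gamma
W_fold = inv * W                      # scales each output column
b_fold = inv * (b - moving_mean) + beta

# Reference: dense layer followed by inference-mode batchnorm
y_ref = gamma * ((x @ W + b) - moving_mean) / np.sqrt(moving_variance + epsilon) + beta
assert np.allclose(x @ W_fold + b_fold, y_ref)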
12 changes: 12 additions & 0 deletions hls4ml/model/hls_model.py
@@ -59,6 +59,18 @@ def get_project_name(self):
def get_output_dir(self):
return self.get_config_value('OutputDir')

def get_merged_relu(self, default=None):
hls_config = self.config['HLSConfig']

model_config = hls_config.get('Model', None)
key = 'MergedRelu'

if model_config is not None:
tempbool = model_config.get(key, default)
return tempbool

return default

def get_layer_config_value(self, layer, key, default=None):
hls_config = self.config['HLSConfig']

2 changes: 2 additions & 0 deletions hls4ml/model/optimizer/__init__.py
@@ -12,6 +12,7 @@
from hls4ml.model.optimizer.passes.conv_same_pad import InsertZeroPaddingBeforeConv2D
from hls4ml.model.optimizer.passes.pointwise import OptimizePointwiseConv
from hls4ml.model.optimizer.passes.clone import CloneOutput
from hls4ml.model.optimizer.passes.relu_merge import MergeRelu
from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, BroadcastStream, RemoveFinalReshape
from hls4ml.model.optimizer.passes.transpose_opt import RemoveUselessTranspose
from hls4ml.model.optimizer.passes.multi_dense import ReplaceMultidimensionalDenseWithConv
@@ -40,6 +41,7 @@
register_pass('conv2d_same_pad', InsertZeroPaddingBeforeConv2D)
register_pass('optimize_pointwise_conv', OptimizePointwiseConv)
register_pass('clone_output', CloneOutput)
register_pass('relu_merge', MergeRelu)
register_pass('remove_final_reshape', RemoveFinalReshape)
register_pass('reshape_stream', ReshapeStream)
register_pass('remove_useless_transpose', RemoveUselessTranspose)
48 changes: 48 additions & 0 deletions hls4ml/model/optimizer/passes/relu_merge.py
@@ -0,0 +1,48 @@
from hls4ml.model.optimizer import OptimizerPass

class MergeRelu(OptimizerPass):
def match(self, node):
supported_layers = ['Conv2D', 'Conv2DBatchnorm', 'Dense']
is_match = node.get_input_node().__class__.__name__ in supported_layers

# hls4ml names ReLU activations 'Activation'
is_match = is_match and (node.__class__.__name__ == 'Activation')
return is_match

def transform(self, model, node):
# Merge ReLU and Convolution/Dense layer
previous_node = node.get_input_node()
previous_node.index = node.index
previous_node.set_merged_relu(True) # Turn on merged_relu flag for this Conv/Dense layer
if 'Conv2D' in previous_node.__class__.__name__:
if previous_node.get_attr('data_format') == 'channels_last':
shape = [previous_node.attributes['out_height'], previous_node.attributes['out_width'], previous_node.attributes['n_filt']]
dims = ['OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index), 'N_FILT_{}'.format(previous_node.index)]
else:
shape = [previous_node.attributes['n_filt'], previous_node.attributes['out_height'], previous_node.attributes['out_width']]
dims = ['N_FILT_{}'.format(previous_node.index), 'OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index)]
activation_precision, _ = model.config.get_precision(node, var='result')
previous_node.add_output_variable(shape, dims, precision=activation_precision)
if not node.get_output_nodes():
print("WARNING: {} is the output layer! No rewiring performed.".format(node.name))
model.remove_node(node, rewire=False)
else:
model.remove_node(node, rewire=True)
return True
elif 'Dense' in previous_node.__class__.__name__:
shape = previous_node.get_input_variable().shape[:]
shape[-1] = previous_node.attributes['n_out']
if len(shape) > 1:
dims = ['N_LAYER_{}_{}'.format(i, previous_node.index) for i in range(1, len(shape) + 1)]
else:
dims = ['N_LAYER_{}'.format(previous_node.index)]
print('shape: {}'.format(shape))
print('dims: {}'.format(dims))
activation_precision, _ = model.config.get_precision(node, var='result')
previous_node.add_output_variable(shape, dims, precision=activation_precision)
if not node.get_output_nodes():
print("WARNING: {} is the output layer! No rewiring performed.".format(node.name))
model.remove_node(node, rewire=False)
else:
model.remove_node(node, rewire=True)
return True
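The relu_merge pass above fuses a ReLU Activation node into the preceding Conv2D, Conv2DBatchnorm or Dense layer: it sets that layer's merged_relu flag, re-creates its output variable with the activation's shape and result precision, and removes the Activation node, so the HLS templates see the merged_relu and out_t parameters added in hls_layers.py. The get_merged_relu accessor added in hls_model.py reads a model-level MergedRelu key; this diff does not show where that key is consumed, but a config enabling the behaviour would presumably look like the sketch below (keras_model is a placeholder; only the MergedRelu key is specific to this PR, the rest is the usual hls4ml conversion flow):

import hls4ml

config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
config['Model']['MergedRelu'] = True   # model-level flag read by get_merged_relu

hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model,
    hls_config=config,
    output_dir='my-hls-prj',
    backend='VivadoAccelerator',
)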
98 changes: 98 additions & 0 deletions hls4ml/model/profiling.py
@@ -1,3 +1,6 @@
from pyDigitalWaveTools.vcd.parser import VcdParser

import hls4ml
from hls4ml.model.hls_model import HLSModel
from hls4ml.model.hls_layers import IntegerPrecisionType, FixedPrecisionType
import matplotlib.pyplot as plt
@@ -26,6 +29,101 @@
__torch_profiling_enabled__ = False


def optimize_fifos_depth(hls_model, init_large_fifo=True, reset=True, csim=True, synth=True,
cosim=True, validation=True, export=True, vsynth=True, **kwargs,):

cfg = hls_model.config.config.copy()
hls_config = cfg['HLSConfig']
out_dir = hls_model.config.get_output_dir()

values = []

def populate_values(name, data, depth):
values.append({'name': name, 'data': [], 'max': 0, 'depth': 0})
get_values = lambda x: int(x[1][1:], 2)
values[-1]['data'] = [get_values(x) for x in data]
values[-1]['max'] = max(values[-1]['data'])
values[-1]['depth'] = int(depth[1:], 2)

if not hls_config['Model']['FIFO_opt']:
raise Exception('To use this optimization you have to set `FIFO_opt` field to True in the HLS config')


# initialize all the fifos to 10000 so that they will be automatically implemented in BRAMs and so they will be
# profiled

if init_large_fifo:

for k,_ in hls_model.output_vars.items():
if k not in hls_config['LayerName']:
hls_config['LayerName'][k] = {'StreamDepth': 10000}
else:
hls_config['LayerName'][k]['StreamDepth'] = 10000

if hls_model.config.get_config_value('Backend') == 'VivadoAccelerator':
hls_config['LayerName']['in_local'] = {'StreamDepth' : 10000}
hls_config['LayerName']['out_local'] = {'StreamDepth': 10000}

cfg['OutputDir'] = out_dir + "_LARGE_FIFO"
cfg['HLSConfig'] = hls_config
hls_model = hls4ml.converters.keras_to_hls(cfg)


# run the build with FIFO_opt param set to 1 in order to generate the vcd file
hls_model.write()
hls_model.build(csim=True, cosim=True, synth=True, vsynth=False, export=False, validation=True)

with open(hls_model.config.get_output_dir() + '/' + hls_model.config.get_project_name() + '_prj' + '/solution1/sim/verilog/fifo_opt.vcd') as vcd_file:
vcd = VcdParser()
vcd.parse(vcd_file)
data = vcd.scope.toJson()

# wrapper fifos - useful only with VivadoAccelerator backend
if hls_model.config.get_config_value('Backend') == 'VivadoAccelerator':
for i in range(1, len(data['children'][0]['children'][0]['children'])):
populate_values(data['children'][0]['children'][0]['children'][i]['name'],
data['children'][0]['children'][0]['children'][i]['children'][0]['data'],
data['children'][0]['children'][0]['children'][i]['children'][1]['data'][0][1])

# model layers fifos
n_elem = len(data['children'][0]['children'][0]['children'][0]['children'])
for i in range(n_elem):
populate_values(data['children'][0]['children'][0]['children'][0]['children'][i]['name'],
data['children'][0]['children'][0]['children'][0]['children'][i]['children'][0]['data'],
data['children'][0]['children'][0]['children'][0]['children'][i]['children'][1]['data'][0][1])

maxs = [{'name': i['name'], 'max': i['max'], 'depth': i['depth']} for i in values]

with open(hls_model.config.get_output_dir() + '/max_depth.json', 'w') as f:
json.dump(maxs, f, indent=4)

new_config = cfg.copy()['HLSConfig']
new_config['Model']['FIFO_opt'] = 0
for k, v in hls_model.output_vars.items():
filtered_max = [x['max'] for x in maxs if v.cppname in x['name']]
if len(filtered_max) == 0:
continue
if len(filtered_max) > 1:
print('WARNING! Check names of FIFOs')
if k not in new_config['LayerName']:
new_config['LayerName'][k] = {'StreamDepth': filtered_max[0] + 1}
else:
new_config['LayerName'][k]['StreamDepth'] = filtered_max[0] + 1
for x in maxs:
if 'in_local' in x['name']:
new_config['LayerName']['in_local'] = {'StreamDepth': x['max'] + 1}
elif 'out_local' in x['name']:
new_config['LayerName']['out_local'] = {'StreamDepth': x['max'] + 1}

cfg['OutputDir'] = out_dir + '_FIFO_OPT'
cfg['HLSConfig'] = new_config
hls_model = hls4ml.converters.keras_to_hls(cfg)
hls_model.write()
hls_model.build(reset=reset, csim=csim, synth=synth, cosim=cosim, validation=validation, export=export, vsynth=vsynth)
print('[hls4ml] - FIFO optimization completed')
return hls_model


def get_unoptimized_hlsmodel(model):
from hls4ml.converters import convert_from_config

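For reference, a minimal usage sketch of the new optimize_fifos_depth helper (cfg below is a placeholder for a Keras-derived hls4ml configuration dict, typically with io_type io_stream and a Vivado or VivadoAccelerator backend). The helper requires the model-level FIFO_opt flag, first rebuilds the project with every stream depth forced to 10000 (output directory suffixed _LARGE_FIFO), runs C simulation and co-simulation to record per-FIFO occupancy from fifo_opt.vcd, writes max_depth.json, and then rebuilds with each layer's StreamDepth set to the observed maximum plus one (directory suffixed _FIFO_OPT):

import hls4ml
from hls4ml.model.profiling import optimize_fifos_depth

cfg['HLSConfig']['Model']['FIFO_opt'] = 1        # required, see the check in the helper
hls_model = hls4ml.converters.keras_to_hls(cfg)  # the helper re-runs keras_to_hls on a copy of cfg

opt_model = optimize_fifos_depth(hls_model,
                                 init_large_fifo=True,
                                 reset=True, csim=True, synth=True,
                                 cosim=True, validation=True,
                                 export=True, vsynth=True)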