From e077161738bd091de3560e866bc06d0f5a6ba56b Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Tue, 28 Nov 2023 15:49:59 -0500 Subject: [PATCH 01/47] playing with brevitas --- hls4ml/converters/pytorch/_symbolic_trace.py | 1118 ++++++++++++++++++ hls4ml/converters/pytorch/convolution.py | 37 +- hls4ml/converters/pytorch/core.py | 4 +- hls4ml/converters/pytorch/pooling.py | 21 +- hls4ml/converters/pytorch_to_hls.py | 19 +- test_brevitas.py | 145 +++ 6 files changed, 1325 insertions(+), 19 deletions(-) create mode 100644 hls4ml/converters/pytorch/_symbolic_trace.py create mode 100644 test_brevitas.py diff --git a/hls4ml/converters/pytorch/_symbolic_trace.py b/hls4ml/converters/pytorch/_symbolic_trace.py new file mode 100644 index 0000000000..1d2ccafdfc --- /dev/null +++ b/hls4ml/converters/pytorch/_symbolic_trace.py @@ -0,0 +1,1118 @@ +import builtins +import copy +import functools +import inspect +import math +import os +import warnings +import collections +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] + +from ._compatibility import compatibility +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: Dict[Type, None] = {} + +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + tracing. 
+ """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_lnotab, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. + + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). Backward compatibility for this + parameter is guaranteed. 
+ + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: Set[int] = { + id(value) + for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: List[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + # Mapping of node name to module scope + self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. 
+ if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, (torch.Tensor, ScriptObject)): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + i = 0 + while True: + qualname = f"_tensor_constant{i}" + if not hasattr(self.root, qualname): + break + i += 1 + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + i = 0 + while True: + qualname = f"_{a.__class__.__name__}_constant_{i}" + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. + """ + return ( + (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn") or m.__module__.startswith("brevitas.nn")) + and not isinstance(m, torch.nn.Sequential) + ) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. 
+ else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + self.module_stack[_scope.module_path] = _scope.module_type + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + key, _ = self.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. + + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. 
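A minimal sketch of how this interacts with ``param_shapes_constant`` (the module below is illustrative, not part of this file): when the flag is set, parameter attributes such as ``weight.shape`` are resolved to concrete values through ``ParameterProxy`` instead of being recorded as nodes in the traced graph::

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 8)

        def forward(self, x):
            # with param_shapes_constant=True this shape access is folded
            # into a constant rather than traced as graph nodes
            return x.reshape(-1, self.lin.weight.shape[1])

    graph = Tracer(param_shapes_constant=True).trace(TinyModel())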
+ """ + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: List[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + def proxy_placeholder(name: str): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default = ( + () + if param.default is inspect.Parameter.empty + else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if x == PH: + return out + # Union[int, bool] == bool in Python <= 3.6 + if ( + type(x) == bool + or type(x) in base_types + and type(x) != torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif type(x) == type(None): + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." 
+ ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default = () + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None) + ) + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. 
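For orientation, a minimal usage sketch (``MyModule`` stands in for any ``nn.Module``): the returned ``Graph`` is typically wrapped in a ``GraphModule`` to obtain an executable module with regenerated Python code::

    tracer = Tracer()
    graph = tracer.trace(MyModule())
    gm = GraphModule(tracer.root, graph)  # executable module built from the graph
    print(gm.code)                        # regenerated forward() source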
+ """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + self.root = root + + assert hasattr( + type(root), self.traced_func_name + ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: Dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. + new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k in {'_autowrap_search'}: + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + +# List of pairs of (global dict, function name) functions +# to patch for the purposes of the wrap() API. 
+_wrapped_fns_to_patch: List[Tuple[dict, str]] = [] + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: List[Tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + + def revert(self): + raise NotImplementedError() + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + +class _Patcher: + def __init__(self): + super().__init__() + self.patches_made: List[_PatchedFn] = [] + self.visited: Set[int] = set() + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. 
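A sketch of the intended pattern (the patched target here is chosen arbitrarily): every patch installed through ``_Patcher`` is reverted, in reverse order, when the ``with`` block exits::

    with _Patcher() as patcher:
        patcher.patch(math.__dict__, "sqrt", _create_wrapped_func(math.sqrt))
        # inside the block, math.sqrt is the wrapped version that records a
        # call_function node whenever it sees Proxy arguments
    # on exit, the original math.sqrt is restored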
+ """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name]) + ) + frame_dict[name] = new_fn + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) + setattr(cls, name, new_fn) + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for frame_dict, name in _wrapped_fns_to_patch: + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + torch.fx.wrap('my_custom_function') + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. 
+ + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch.append((f.f_globals, fn_name)) + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a*2 + + FX can typically not trace through this due to the presence of control + flow. However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={'b': False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) + assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. 
+ """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return GraphModule(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index c17de86e9d..d10aca4843 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,8 +1,25 @@ from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format - - -@pytorch_handler('Conv1d') +from hls4ml.model.types import FixedPrecisionType +import math + +def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point): + """ + parameters: + bitwidth: int + scale_factor: float + zero_point: float + + return: + int_bitwidth: int + fract_bitwidth: int + """ + fract_bitwidth = - math.log2(scale_factor) + int_bitwidth = bitwidth - fract_bitwidth + + return (fract_bitwidth, int_bitwidth) + +@pytorch_handler('Conv1d', 'QuantConv1d') def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Conv1d' in operation @@ -47,7 +64,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c return layer, output_shape -@pytorch_handler('Conv2d') +@pytorch_handler('Conv2d', 'QuantConv2d') def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Conv2d' in operation @@ -57,8 +74,16 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['class_name'] = 'Conv2D' layer['data_format'] = 'channels_first' # Pytorch default (can't change) - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + if "Quant" in operation: + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['bias_data'] = class_object.quant_bias().detach().value.numpy() + width = class_object.quant_weight().bit_width + ap_fixed_params = ConvUAQToAp_Fixed(width, class_object.quant_weight().scale,0) + layer['precision'] = FixedPrecisionType(width=width, integer=ap_fixed_params[1], signed=True) + + else: + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') # Input info (layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format( input_shapes[0], 'channels_first' diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 5a2caba1ea..67a374c527 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,7 +1,7 @@ from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler -@pytorch_handler('Linear') +@pytorch_handler('Linear', 'QuantLinear') def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Linear' in operation @@ -29,7 +29,7 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c return layer, output_shape -activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 'Sigmoid', 'Tanh'] +activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 
'Sigmoid', 'Tanh','QuantReLU','QuantSigmoid','QuantTanh'] @pytorch_handler(*activation_layers) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index 3076f1f38a..fbeca5407c 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -1,7 +1,7 @@ from hls4ml.converters.pytorch_to_hls import pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format -pooling_layers = ['MaxPool1d', 'MaxPool2d', 'AvgPool1d', 'AvgPool2d'] +pooling_layers = ['MaxPool1d', 'MaxPool2d', 'AvgPool1d', 'AvgPool2d', 'QuantMaxPool1d', 'QuantMaxPool2d'] #TODO add support for special quantized average pool layers @pytorch_handler(*pooling_layers) @@ -85,18 +85,25 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, padding = [class_object.padding, class_object.padding] else: - if type(node.kwargs['stride']) is tuple: + if node.kwargs['stride'] is None: + if type(node.args[-1]) is tuple: + layer['stride_height'] = node.args[-1][0] + layer['stride_width'] = node.args[-1][0] + else: + layer['stride_height'] = node.args[-1] + layer['stride_width'] = node.args[-1] + elif type(node.kwargs['stride']) is tuple: layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1] else: layer['stride_height'] = node.kwargs['stride'] layer['stride_width'] = node.kwargs['stride'] - if type(node.kwargs['kernel_size']) is tuple: - layer['pool_height'] = node.kwargs['kernel_size'][0] - layer['pool_width'] = node.kwargs['kernel_size'][1] + if type(node.args[-1]) is tuple: + layer['pool_height'] = node.args[-1][0] + layer['pool_width'] = node.args[-1][0] else: - layer['pool_height'] = node.kwargs['kernel_size'] - layer['pool_width'] = node.kwargs['kernel_size'] + layer['pool_height'] = node.args[-1] + layer['pool_width'] = node.args[-1] if type(node.kwargs['padding']) is tuple: padding = node.kwargs['padding'] diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 961fb735af..c66de48bf7 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -3,6 +3,16 @@ from hls4ml.model import ModelGraph +class CustomFXTracer(torch.fx.Tracer): + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX + """ + return ( + (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn") or m.__module__.startswith("brevitas.nn")) + and not isinstance(m, torch.nn.Sequential) + ) class PyTorchModelReader: """ @@ -24,6 +34,7 @@ def get_weights_data(self, layer_name, var_name): # if a layer is reused in the model, torch.FX will append a "_n" for the n-th use # have to snap that off to find the tensors + print (self.state_dict) if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1: layer_name = '_'.join(layer_name.split('_')[:-1]) @@ -140,9 +151,9 @@ def pytorch_to_hls(config): # dict of layer objects in non-traced form for access lateron children = {c[0]: c[1] for c in model.named_children()} # use symbolic_trace to get a full graph of the model - from torch.fx import symbolic_trace - traced_model = symbolic_trace(model) + tracer = CustomFXTracer() + traced_model = tracer.trace(model) # Define layers to skip for conversion to HLS skip_layers = ['Dropout', 'Flatten', 'Sequential'] @@ -160,8 +171,8 @@ def 
pytorch_to_hls(config): layer_counter = 0 n_inputs = 0 - - for node in traced_model.graph.nodes: + print (traced_model) + for node in traced_model.nodes: # If part of a nn.Sequntial, the node name will start with an "_" which messes up the parsing if node.name[0] == '_': node.name = 'layer' + node.name diff --git a/test_brevitas.py b/test_brevitas.py new file mode 100644 index 0000000000..755d09bae2 --- /dev/null +++ b/test_brevitas.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import Module +import torch.nn.functional as F + +import brevitas.nn as qnn +from brevitas.quant import Int8WeightPerTensorFixedPoint + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +import math + + + + +""" +conversion function from target ap_fixed to parameters for UAQ +""" +from typing import Tuple + +def ConvAp_FixedToUAQ(int_bitwidth, fract_bitwidth) -> Tuple[int,float,float]: + """ + parameters: + int_bitwidth: int + fract_bitwidth: int + + return: + bitwidth: int + scale_factor: float + zero_point: float + """ + bitwidth = int_bitwidth + fract_bitwidth + scale_factor = 2**(-fract_bitwidth) + zero_point = 0 # we assume int representation is signed + + return (bitwidth, scale_factor, zero_point) + + + +""" +conversion function from UAQ to ap_fixed +""" +from typing import Tuple + +def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point) -> Tuple[int,int]: + """ + parameters: + bitwidth: int + scale_factor: float + zero_point: float + + return: + int_bitwidth: int + fract_bitwidth: int + """ + fract_bitwidth = - math.log2(scale_factor) + int_bitwidth = bitwidth - fract_bitwidth + + return (bitwidth, int_bitwidth) + + + +class QuantWeightLeNet(Module): + def __init__(self): + super(QuantWeightLeNet, self).__init__() + self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + self.relu1 = nn.ReLU() + self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4) + self.relu2 = nn.ReLU() + self.fc1 = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4) + self.relu3 = nn.ReLU() + self.fc2 = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4) + self.relu4 = nn.ReLU() + self.fc3 = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4) + + def forward(self, x): + out = self.relu1(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = self.relu2(self.conv2(out)) + out = F.max_pool2d(out, 2) + #out = out.reshape(out.reshape[0], -1) + out = self.relu3(self.fc1(out)) + out = self.relu4(self.fc2(out)) + out = self.fc3(out) + return out + +quant_weight_lenet = QuantWeightLeNet() + + +class QuantModel(Module): + def __init__(self): + super(QuantModel, self).__init__() + self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + self.relu1 = nn.ReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + +quant_weight_lenet = QuantWeightLeNet() + +model = QuantModel() + + +x = torch.randn(3,6,5) + +quant_linear = qnn.QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint, bias=False) +print(f"Weight QuantTensor Linear:\n {quant_linear.quant_weight()}") +print(f"Quant Weight fix point: {- math.log2(quant_linear.quant_weight().scale)}") +print(f"Quant Weight scale: {quant_linear.quant_weight().scale}") +print(f"Quant Weight bit width: {quant_linear.quant_weight().bit_width}") +print(f"Quant Weight zero point: {quant_linear.quant_weight().zero_point}") 
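As a quick, self-contained sanity check of the two conversion helpers defined above (the numeric values are chosen purely for illustration): an ``ap_fixed<8,2>`` format, i.e. 2 integer and 6 fractional bits, corresponds to an 8-bit integer representation with scale 2**-6 and zero point 0. Note that ``ConvUAQToAp_Fixed`` as written returns (total width, integer bits), which is the pair the ``ap_fixed<W,I>`` template expects::

    bitwidth, scale, zero_point = ConvAp_FixedToUAQ(int_bitwidth=2, fract_bitwidth=6)
    assert (bitwidth, scale, zero_point) == (8, 2 ** -6, 0)
    total_width, int_bits = ConvUAQToAp_Fixed(bitwidth, scale, zero_point)
    assert (total_width, int_bits) == (8, 2)  # maps back to ap_fixed<8,2>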
+ +pytorch_prediction = model(x).detach().numpy() +print(f"Weight Tensor:\n {model.conv1.weight}") +print(f"Weight QuantTensor:\n {model.conv1.quant_weight()}") +print(f"Quant Weight fix point: {- math.log2(model.conv1.quant_weight().scale)}") +print(f"Quant Weight scale: {model.conv1.quant_weight().scale}") +print(f"Quant Weight bit width: {model.conv1.quant_weight().bit_width}") +print(f"Quant Weight zero point: {model.conv1.quant_weight().zero_point}") +ap_fixed_params = ConvUAQToAp_Fixed(8, model.conv1.quant_weight().scale,0) +print (ap_fixed_params) +config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) +#config['Model']['Precision'] = 'ap_fixed<%d,%d>'%(ap_fixed_params[0],ap_fixed_params[1]) +print (config) +output_dir = "test_pytorch" +backend = "Vivado" +io_type = 'io_parallel' + +hls_model = convert_from_pytorch_model( + model, + (None, 3,6,5), + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, +) +hls_model.compile() + +hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) +print(pytorch_prediction) +print(hls_prediction) \ No newline at end of file From c946a03339060e616eca0ed122baf83790d2203d Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 15 Dec 2023 10:35:33 -0500 Subject: [PATCH 02/47] add brevitas quantizer --- hls4ml/converters/pytorch/_symbolic_trace.py | 1118 ------------------ hls4ml/converters/pytorch/convolution.py | 23 +- hls4ml/converters/pytorch_to_hls.py | 2 - hls4ml/model/types.py | 14 + 4 files changed, 31 insertions(+), 1126 deletions(-) delete mode 100644 hls4ml/converters/pytorch/_symbolic_trace.py diff --git a/hls4ml/converters/pytorch/_symbolic_trace.py b/hls4ml/converters/pytorch/_symbolic_trace.py deleted file mode 100644 index 1d2ccafdfc..0000000000 --- a/hls4ml/converters/pytorch/_symbolic_trace.py +++ /dev/null @@ -1,1118 +0,0 @@ -import builtins -import copy -import functools -import inspect -import math -import os -import warnings -import collections -from itertools import chain -from types import CodeType, FunctionType, ModuleType -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - Tuple, - Type, - Union, -) - -import torch -import torch.utils._pytree as pytree -from torch._C import ScriptObject # type: ignore[attr-defined] - -from ._compatibility import compatibility -from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph -from .graph_module import GraphModule -from .node import Argument, base_types, map_aggregate -from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager - -HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS - -# These need to run in global scope to handle nested calls correctly -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ - -_proxyable_classes: Dict[Type, None] = {} - -_is_fx_tracing_flag = False - - -def is_fx_tracing(): - return _is_fx_tracing_flag - -@compatibility(is_backward_compatible=True) -class ProxyableClassMeta(type): - """ - ProxyableClassMeta allows you to make construction of a given Python class - symbolically traceable. 
For example:: - - import torch - import torch.fx - - class TensorPair(metaclass=torch.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): - s = x.add(TensorPair(y, y)) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = torch.randn(5, 3) - ref_out = use_tensor_pair_ctor(x, y) - - traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) - print(traced.code) - ''' - def forward(self, x : __main___TensorPair, y : torch.Tensor): - tensor_pair = __main___TensorPair(y, y); y = None - add = x.add(tensor_pair); tensor_pair = None - mul = add.mul(x); add = x = None - return mul - ''' - - From this example, we can see that construction of a class (``TensorPair``) - defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic - tracing. - """ - - def __init__(cls, name, bases, attrs): - _proxyable_classes.setdefault(cls) - super().__init__(name, bases, attrs) - - def __call__(cls, *args, **kwargs): - instance = cls.__new__(cls) # type: ignore[call-overload] - - found_proxies = [] - - def check_proxy(a): - if isinstance(a, Proxy): - found_proxies.append(a) - - map_aggregate(args, check_proxy) - map_aggregate(kwargs, check_proxy) - - if len(found_proxies) != 0: - tracer = found_proxies[0].tracer - return tracer.create_proxy("call_function", cls, args, kwargs) - else: - cls.__init__(instance, *args, **kwargs) # type: ignore[misc] - return instance - - -def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: - co = fn.__code__ - co_flags = co.co_flags & ~HAS_VARSTUFF - co_args: tuple - if hasattr(co, "co_qualname"): - # Python-3.11+ code signature - co_args = ( - nargs, - 0, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_qualname, # type: ignore[attr-defined] - co.co_firstlineno, - co.co_lnotab, - co.co_exceptiontable, # type: ignore[attr-defined] - co.co_freevars, - co.co_cellvars, - ) - elif hasattr(co, "co_posonlyargcount"): - co_args = ( - nargs, - 0, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - else: - co_args = ( - nargs, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - new_code = CodeType(*co_args) # type: ignore[arg-type] - return FunctionType( - new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ - ) - - # we need to insert placeholder nodes for *args and **kwargs - # we can't call this function normally, otherwise it would try to unpack them - # instead, let's make python think that args and kwargs are normal variables - - -@compatibility(is_backward_compatible=False) -class PHBase: - """ - Object representing an input placeholder to `concrete_args` - """ - - def __repr__(self): - return "PH" - - -PH = PHBase() - - -@compatibility(is_backward_compatible=True) -class Tracer(TracerBase): - # Reference: 
https://github.com/pytorch/pytorch/issues/54354 - # The first line of this docstring overrides the one Sphinx generates for the - # documentation. We need it so that Sphinx doesn't leak `math`s path from the - # build environment (e.g. ` None: - # This method's signature is overridden by the first line of this class' - # docstring. If this method's signature is modified, the signature that - # overrides it also should be modified accordingly. - - """ - Construct a Tracer object. - - Args: - - autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, - Python modules whose functions should be wrapped automatically - without needing to use fx.wrap(). Backward-compatibility for - this parameter is guaranteed. - - autowrap_functions (Tuple[Callable, ...]): defaults to `()`, - Python functions that should be wrapped automatically without - needing to use fx.wrap(). Backward compatibility for this - parameter is guaranteed. - - param_shapes_constant (bool): When this flag is set, calls to shape, - size and a few other shape like attributes of a module's parameter - will be evaluated directly, rather than returning a new Proxy value - for an attribute access. Backward compatibility for this parameter - is guaranteed. - """ - - super().__init__() - - # Functions we will eagerly wrap when we see them while tracing - # this captures both `math.sqrt()` and `from math import sqrt` automatically - self._autowrap_function_ids: Set[int] = { - id(value) - for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) - if not name.startswith("_") and callable(value) - } - self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) - - # Python modules to apply autowrap to at the start, in addition to - # modules we see while tracing - self._autowrap_search: List[ModuleType] = list(autowrap_modules) - self.param_shapes_constant = param_shapes_constant - - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None - self.root_module_name: str = "" - # Maps the containing module's name to the operator name - self.scope = Scope("", None) - # Records the module call stack - self.module_stack = collections.OrderedDict() - # Mapping of node name to module scope - self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} - - @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> "Argument": - """ - A method to specify the behavior of tracing when preparing values to - be used as arguments to nodes in the ``Graph``. - - By default, the behavior includes: - - #. Iterate through collection types (e.g. tuple, list, dict) and recursively - call ``create_args`` on the elements. - #. Given a Proxy object, return a reference to the underlying IR ``Node`` - #. Given a non-Proxy Tensor object, emit IR for various cases: - - * For a Parameter, emit a ``get_attr`` node referring to that Parameter - * For a non-Parameter Tensor, store the Tensor away in a special - attribute referring to that attribute. - - This method can be overridden to support more types. - - Args: - - a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. - - - Returns: - - The value ``a`` converted into the appropriate ``Argument`` - """ - # The base tracer is used to construct Graphs when there is no associated - # module hierarchy, so it can never create parameter references. - # The default tracer adds the ability to refer to parameters when - # tracing modules. 
- if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node("get_attr", n, (), {}) - raise NameError("parameter is not a member of this module") - elif isinstance(a, torch.Tensor): - for n_, p_ in self.root.named_buffers(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - elif isinstance(a, torch.nn.Module): - for n_, p_ in self.root.named_modules(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - # For NamedTuple instances that appear literally as args, we emit - # a node to construct the NamedTuple and use that Node as the argument. - if isinstance(a, tuple) and hasattr(a, "_fields"): - args = tuple(self.create_arg(elem) for elem in a) - return self.create_node("call_function", a.__class__, args, {}) - - # Tensors do not have a reliable string repr() from which they can be - # constructed (and we probably don't want to rely on that, either), so - # for any constant Tensor values we encounter, first search for if they - # are an attribute of some module in the module hierarchy. If so, emit - # a get_attr to retrieve that tensor. Otherwise, we'll store away the - # tensor value into a special attribute on the Module s.t. we can - # retrieve it with a get_attr. - if isinstance(a, (torch.Tensor, ScriptObject)): - qualname: Optional[str] = self.tensor_attrs.get(a) - - # Tensor was not found in the Module hierarchy, stow it away in a - # special attribute and set the qualname to refer to that - if not qualname: - i = 0 - while True: - qualname = f"_tensor_constant{i}" - if not hasattr(self.root, qualname): - break - i += 1 - self.tensor_attrs[a] = qualname - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - if type(a) in _proxyable_classes: - # This is an instance of a proxyable class for which we did not - # witness its construction. Intern this as a constant attribute - - # TODO: binary search - i = 0 - while True: - qualname = f"_{a.__class__.__name__}_constant_{i}" - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - return super().create_arg(a) - - @compatibility(is_backward_compatible=True) - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - """ - A method to specify whether a given ``nn.Module`` is a "leaf" module. - - Leaf modules are the atomic units that appear in - the IR, referenced by ``call_module`` calls. By default, - Modules in the PyTorch standard library namespace (torch.nn) - are leaf modules. All other modules are traced through and - their constituent ops are recorded, unless specified otherwise - via this parameter. - - Args: - - m (Module): The module being queried about - module_qualified_name (str): The path to root of this module. For example, - if you have a module hierarchy where submodule ``foo`` contains - submodule ``bar``, which contains submodule ``baz``, that module will - appear with the qualified name ``foo.bar.baz`` here. - """ - return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn") or m.__module__.startswith("brevitas.nn")) - and not isinstance(m, torch.nn.Sequential) - ) - - @compatibility(is_backward_compatible=True) - def path_of_module(self, mod: torch.nn.Module) -> str: - """ - Helper method to find the qualified name of ``mod`` in the Module hierarchy - of ``root``. 
For example, if ``root`` has a submodule named ``foo``, which has - a submodule named ``bar``, passing ``bar`` into this function will return - the string "foo.bar". - - Args: - - mod (str): The ``Module`` to retrieve the qualified name for. - """ - # Prefer the O(1) algorithm - if self.submodule_paths: - path = self.submodule_paths.get(mod) - if path is None: - raise NameError("module is not installed as a submodule") - assert isinstance(path, str) - return path - # O(N^2) fallback in the case that we didn't store the submodule - # paths. - else: - for n, p in self.root.named_modules(): - if mod is p: - return n - raise NameError("module is not installed as a submodule") - - @compatibility(is_backward_compatible=True) - def call_module( - self, - m: torch.nn.Module, - forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - ) -> Any: - """ - Method that specifies the behavior of this ``Tracer`` when it encounters - a call to an ``nn.Module`` instance. - - By default, the behavior is to check if the called module is a leaf module - via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to - ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through - the operations in its ``forward`` function. - - This method can be overridden to--for example--create nested traced - GraphModules, or any other behavior you would want while tracing across - ``Module`` boundaries. - - Args: - - m (Module): The module for which a call is being emitted - forward (Callable): The forward() method of the ``Module`` to be invoked - args (Tuple): args of the module callsite - kwargs (Dict): kwargs of the module callsite - - Return: - - The return value from the Module call. In the case that a ``call_module`` - node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever - value was returned from the ``Module`` invocation. - """ - module_qualified_name = self.path_of_module(m) - with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: - # module_stack is an ordered dict so writing then deleting the - # entry is equivalent to push/pop on a list - self.module_stack[_scope.module_path] = _scope.module_type - if not self.is_leaf_module(m, module_qualified_name): - ret_val = forward(*args, **kwargs) - else: - ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) - key, _ = self.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" - - return ret_val - - @compatibility(is_backward_compatible=False) - def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): - """ - Method that specifies the behavior of this ``Tracer`` when we call getattr - on a call to an ``nn.Module`` instance. - - By default, the behavior is to return a proxy value for the attribute. It - also stores the proxy value in the ``parameter_proxy_cache``, so that future - calls will reuse the proxy rather than creating a new one. - - This method can be overridden to --for example-- not return proxies when - querying parameters. - - Args: - - attr (str): The name of the attribute being queried - attr_val (Any): The value of the attribute - parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies - - Return: - - The return value from the getattr call. 
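# One consequence of the parameter proxying described here, assuming stock
# torch.fx behaviour: with ``param_shapes_constant=True`` shape queries on
# parameters are answered with concrete values at trace time instead of
# becoming graph nodes, so they can drive Python control flow. A small sketch:
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(8, 4))

    def forward(self, x):
        if self.w.shape[1] == 4:  # evaluated during tracing, not recorded
            return x @ self.w.t()
        return x

graph = torch.fx.Tracer(param_shapes_constant=True).trace(M())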
- """ - def maybe_get_proxy_for_attr( - attr_val, collection_to_search, parameter_proxy_cache - ): - for n, p in collection_to_search: - if attr_val is p: - if n not in parameter_proxy_cache: - kwargs = {} - if ( - "proxy_factory_fn" - in inspect.signature(self.create_proxy).parameters - ): - kwargs["proxy_factory_fn"] = ( - None - if not self.param_shapes_constant - else lambda node: ParameterProxy( - self, node, n, attr_val - ) - ) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] - parameter_proxy_cache[n] = val_proxy - return parameter_proxy_cache[n] - return None - - if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_parameters(), parameter_proxy_cache - ) - if maybe_parameter_proxy is not None: - return maybe_parameter_proxy - - if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_buffers(), parameter_proxy_cache - ) - if maybe_buffer_proxy is not None: - return maybe_buffer_proxy - - return attr_val - - # This method will be refactored - @compatibility(is_backward_compatible=False) - def create_args_for_root(self, root_fn, is_module, concrete_args=None): - """ - Create ``placeholder`` nodes corresponding to the signature of the ``root`` - Module. This method introspects root's signature and emits those - nodes accordingly, also supporting ``*args`` and ``**kwargs``. - """ - # In some cases, a function or method has been decorated with a wrapper - # defined via ``functools.wraps``. In this case, the outer code object - # will likely not contain the actual parameters we care about, so unwrap - # the function to get to the innermost callable. - fn_for_analysis = inspect.unwrap(root_fn) - co = fn_for_analysis.__code__ - total_args = co.co_argcount + co.co_kwonlyargcount - orig_args = list(co.co_varnames) - names_iter = iter(co.co_varnames) - args: List[Any] = [] - skip_arg_idx = 0 - if is_module: - if total_args == 0: - raise RuntimeError( - "``self`` argument cannot be part of *args expansion!" - ) - skip_arg_idx = 1 - next(names_iter) # skip self - args.append(self.root) - - sig = inspect.signature(fn_for_analysis) - - def proxy_placeholder(name: str): - if concrete_args is not None and name in concrete_args: - cnt = 0 - - def replace_ph(x): - nonlocal cnt - cnt += 1 - param = sig.parameters[name] - default = ( - () - if param.default is inspect.Parameter.empty - else (param.default,) - ) - out = self.create_proxy( - "placeholder", f"{name}_{str(cnt)}", default, {} - ) - if x == PH: - return out - # Union[int, bool] == bool in Python <= 3.6 - if ( - type(x) == bool - or type(x) in base_types - and type(x) != torch.Tensor - ): - torch._assert( - out == x, - f"{name} has been specialized to have value {x} but got another value", - ) - elif type(x) == type(None): - args = ( - out, - f"{name} has been specialized to have value None but got another value", - ) - self.create_proxy("call_function", _assert_is_none, args, {}) - else: - warnings.warn( - f"Was not able to add assertion to guarantee correct input {name} to " - f"specialized function. It is up to the user to make sure that your inputs match the " - f"inputs you specialized the function with." 
- ) - - return x - - return pytree.tree_map(replace_ph, concrete_args[name]) - if name[0] == "*": - default = () - else: - param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] - return self.create_proxy( - "placeholder", - name, - default, - {}, - type_expr=fn_for_analysis.__annotations__.get(name, None) - ) - - arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] - if isinstance(concrete_args, tuple): - if len(arg_names) != len(concrete_args): - raise RuntimeError( - f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" - ) - concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} - args.extend(proxy_placeholder(names) for names in arg_names) - - if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: - # TODO: type annotations for *args and **kwargs - if co.co_flags & inspect.CO_VARARGS: - args.append(proxy_placeholder("*" + next(names_iter))) - if co.co_flags & inspect.CO_VARKEYWORDS: - args.append(proxy_placeholder("**" + next(names_iter))) - root_fn = _patch_function(root_fn, len(args)) - - flat_args, in_spec = pytree.tree_flatten(tuple(args)) - if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs): - # In the case that we have pytree-flattened inputs in - # `concrete_args`, generate a flattening wrapper around the - # original root function and return that. - self.graph._codegen = _PyTreeCodeGen( - _PyTreeInfo(orig_args[:total_args], in_spec, None) - ) - - def flatten_fn(*args): - tree_args = pytree.tree_unflatten(list(args), in_spec) - tree_out = root_fn(*tree_args) - out_args, out_spec = pytree.tree_flatten(tree_out) - assert isinstance(self.graph._codegen, _PyTreeCodeGen) - self.graph._codegen.pytree_info = ( - self.graph._codegen.pytree_info._replace(out_spec=out_spec) - ) - return out_args - - return flatten_fn, flat_args - return root_fn, args - - @compatibility(is_backward_compatible=True) - def trace( - self, - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, - ) -> Graph: - """ - Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` - can either be an ``nn.Module`` instance or a Python callable. - - Note that after this call, ``self.root`` may be different from the ``root`` passed - in here. For example, when a free function is passed to ``trace()``, we will - create an ``nn.Module`` instance to use as the root and add embedded constants - to. - - - Args: - - root (Union[Module, Callable]): Either a ``Module`` or a function to be - traced through. Backwards-compatibility for this parameter is - guaranteed. - concrete_args (Optional[Dict[str, any]]): Concrete arguments that should - not be treated as Proxies. This parameter is experimental and - its backwards-compatibility is *NOT* guaranteed. - - Returns: - - A ``Graph`` representing the semantics of the passed-in ``root``. 
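# For orientation, the round trip is small enough to show inline (plain torch
# model, illustrative only); the pytorch_to_hls converter further down walks the
# traced graph's nodes in the same way, and ``GraphModule`` is only needed when
# the traced graph must also remain executable:
import torch
import torch.fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
tracer = torch.fx.Tracer()
graph = tracer.trace(model)
for node in graph.nodes:
    print(node.op, node.name, node.target)
# ops: placeholder, call_module '0' (Linear), call_module '1' (ReLU), output
gm = torch.fx.GraphModule(tracer.root, graph)
print(gm(torch.randn(1, 4)).shape)  # torch.Size([1, 4])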
- """ - global _is_fx_tracing_flag - old_is_fx_tracing_flag = _is_fx_tracing_flag - _is_fx_tracing_flag = True - try: - if isinstance(root, torch.nn.Module): - self.root = root - - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" - - fn = getattr(type(root), self.traced_func_name) - self.root_module_name = root._get_name() - self.submodule_paths = {mod: name for name, mod in root.named_modules()} - else: - self.root = torch.nn.Module() - fn = root - - tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None) - self.graph = Graph(tracer_cls=tracer_cls) - - # When we encounter a Tensor value that's not a parameter, we look if it - # is some other attribute on the model. Construct a dict mapping Tensor - # values to the qualified name here for efficiency. This is used downstream - # in create_arg - self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): - for k, v in m.__dict__.items(): - if isinstance(v, (torch.Tensor, ScriptObject)): - self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) - for k, v in m.named_children(): - collect_tensor_attrs(v, prefix_atoms + [k]) - - collect_tensor_attrs(self.root, []) - - assert isinstance(fn, FunctionType) - - fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root( - fn, isinstance(root, torch.nn.Module), concrete_args - ) - - parameter_proxy_cache: Dict[ - str, Proxy - ] = {} # Reduce number of get_attr calls - - # Method dispatch on parameters is not recorded unless it's directly used. - # Thus, we need to insert a proxy when __getattr__ requests a parameter. - @functools.wraps(_orig_module_getattr) - def module_getattr_wrapper(mod, attr): - attr_val = _orig_module_getattr(mod, attr) - return self.getattr(attr, attr_val, parameter_proxy_cache) - - @functools.wraps(_orig_module_call) - def module_call_wrapper(mod, *args, **kwargs): - def forward(*args, **kwargs): - return _orig_module_call(mod, *args, **kwargs) - - _autowrap_check( - patcher, - getattr(getattr(mod, "forward", mod), "__globals__", {}), - self._autowrap_function_ids, - ) - return self.call_module(mod, forward, args, kwargs) - - with _Patcher() as patcher: - # allow duplicate patches to support the case of nested calls - patcher.patch_method( - torch.nn.Module, - "__getattr__", - module_getattr_wrapper, - deduplicate=False, - ) - patcher.patch_method( - torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False - ) - _patch_wrapped_functions(patcher) - _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) - for module in self._autowrap_search: - _autowrap_check( - patcher, module.__dict__, self._autowrap_function_ids - ) - self.create_node( - "output", - "output", - (self.create_arg(fn(*args)),), - {}, - type_expr=fn.__annotations__.get("return", None), - ) - - self.submodule_paths = None - finally: - _is_fx_tracing_flag = old_is_fx_tracing_flag - return self.graph - - def __deepcopy__(self, memo): - # _autowrap_search contains modules, which cannot be deepcopied. - new_tracer = Tracer.__new__(Tracer) - - for k, v in self.__dict__.items(): - if k in {'_autowrap_search'}: - new_obj = copy.copy(v) - else: - new_obj = copy.deepcopy(v, memo) - - new_tracer.__dict__[k] = new_obj - - return new_tracer - - -# List of pairs of (global dict, function name) functions -# to patch for the purposes of the wrap() API. 
-_wrapped_fns_to_patch: List[Tuple[dict, str]] = [] - -# List of methods on classes to wrap (class type, function name) -# this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch: List[Tuple[type, str]] = [] - -if os.environ.get("FX_PATCH_GETITEM") == "1": - # This change is needed to trace models like PositionalEmbedding from BERT: - # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py - # but causes issues in quantization documented here: - # https://github.com/pytorch/pytorch/issues/50710 - # once that is fixed we can make this the default behavior. - _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) - - -def _find_proxy(*objects_to_search): - """ - Recursively search a data structure for a Proxy() and return it, - return None if not found. - """ - proxy = None - - def find_proxy(x): - nonlocal proxy - if isinstance(x, Proxy): - proxy = x - - map_aggregate(objects_to_search, find_proxy) - return proxy - - -def _create_wrapped_func(orig_fn): - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Given an closed-over ``orig_function`` to invoke, search the args and kwargs for - a Proxy object. If there is one, emit a ``call_function`` node to preserve the - call to this leaf function directly. Otherwise, just return the results of - this function call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return_proxy = proxy.tracer.create_proxy( - "call_function", orig_fn, args, kwargs - ) - return_proxy.node.meta["is_wrapped"] = True - return return_proxy - return orig_fn(*args, **kwargs) - - return wrapped - - -def _create_wrapped_method(cls, name): - orig_fn = getattr(cls, name) - - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Search the args and kwargs for a Proxy object. If there is one, - emit a ``call_method`` node to preserve the call to this method - directly. Otherwise, just return the results of this function - call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return proxy.tracer.create_proxy("call_method", name, args, kwargs) - return orig_fn(*args, **kwargs) - - return wrapped - - -class _PatchedFn(NamedTuple): - frame_dict: Any - fn_name: str - orig_fn: Any - - def revert(self): - raise NotImplementedError() - - -class _PatchedFnSetItem(_PatchedFn): - def revert(self): - self.frame_dict[self.fn_name] = self.orig_fn - - -class _PatchedFnDel(_PatchedFn): - def revert(self): - del self.frame_dict[self.fn_name] - - -class _PatchedFnSetAttr(_PatchedFn): - def revert(self): - setattr(self.frame_dict, self.fn_name, self.orig_fn) - - -class _Patcher: - def __init__(self): - super().__init__() - self.patches_made: List[_PatchedFn] = [] - self.visited: Set[int] = set() - - def patch( - self, - frame_dict: Dict[str, Any], - name: str, - new_fn: Callable, - deduplicate: bool = True, - ): - """ - Replace frame_dict[name] with new_fn until we exit the context manager. 
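# The patching machinery is easiest to see with a toy stand-in (a simplified
# sketch, not the real implementation being removed here): save the original
# binding, install the replacement, and restore everything in reverse order on
# exit.
import math

class TinyPatcher:
    def __init__(self):
        self._saved = []

    def patch(self, frame_dict, name, new_fn):
        self._saved.append((frame_dict, name, frame_dict[name]))
        frame_dict[name] = new_fn

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        while self._saved:
            frame_dict, name, orig = self._saved.pop()
            frame_dict[name] = orig

with TinyPatcher() as p:
    p.patch(math.__dict__, "sqrt", lambda x: "wrapped")
    assert math.sqrt(4) == "wrapped"
assert math.sqrt(4) == 2.0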
- """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - if name not in frame_dict and hasattr(builtins, name): - self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) - elif getattr(frame_dict[name], "__fx_already_patched", False): - return # already patched, no need to do it again - else: - self.patches_made.append( - _PatchedFnSetItem(frame_dict, name, frame_dict[name]) - ) - frame_dict[name] = new_fn - - def patch_method( - self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True - ): - """ - Replace object_or_dict.name with new_fn until we exit the context manager. - """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - orig_fn = getattr(cls, name) - if getattr(orig_fn, "__fx_already_patched", False): - return # already patched, no need to do it again - self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) - setattr(cls, name, new_fn) - - def visit_once(self, thing: Any): - """Return True on the first call to with thing, otherwise false""" - idx = id(thing) - if idx in self.visited: - return False - self.visited.add(idx) - return True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Undo all the changes made via self.patch() and self.patch_method() - """ - while self.patches_made: - # unpatch in reverse order to handle duplicates correctly - self.patches_made.pop().revert() - self.visited.clear() - - -def _patch_wrapped_functions(patcher: _Patcher): - """ - Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap - the listed global functions in the `_create_wrapped_func` wrapper. - """ - for frame_dict, name in _wrapped_fns_to_patch: - if name not in frame_dict and hasattr(builtins, name): - orig_fn = getattr(builtins, name) - else: - orig_fn = frame_dict[name] - patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) - - for cls, name in _wrapped_methods_to_patch: - patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) - - -def _autowrap_check( - patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] -): - """ - Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. - This method searches a scope for them and patches them if found. - """ - if patcher.visit_once(frame_dict): - for name, value in frame_dict.items(): - if ( - not name.startswith("_") - and callable(value) - and id(value) in function_ids - ): - patcher.patch(frame_dict, name, _create_wrapped_func(value)) - - -@compatibility(is_backward_compatible=True) -def wrap(fn_or_name: Union[str, Callable]): - """ - This function can be called at module-level scope to register fn_or_name as a "leaf function". - A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being - traced through:: - - # foo/bar/baz.py - def my_custom_function(x, y): - return x * x + y * y - - torch.fx.wrap('my_custom_function') - - def fn_to_be_traced(x, y): - # When symbolic tracing, the below call to my_custom_function will be inserted into - # the graph rather than tracing it. - return my_custom_function(x, y) - - This function can also equivalently be used as a decorator:: - - # foo/bar/baz.py - @torch.fx.wrap - def my_custom_function(x, y): - return x * x + y * y - - A wrapped function can be thought of a "leaf function", analogous to the concept of - "leaf modules", that is, they are functions that are left as calls in the FX trace - rather than traced through. 
- - Args: - - fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the - graph when it's called - """ - if not callable(fn_or_name) and not isinstance(fn_or_name, str): - raise RuntimeError( - "Unsupported type for global function! Must be either a callable or " - "string name" - ) - - if callable(fn_or_name): - assert not isinstance(fn_or_name, str) # to make mypy happy - fn_name = fn_or_name.__name__ - else: - assert isinstance( - fn_or_name, str - ), "fn_or_name must be a global function or string name" - fn_name = fn_or_name - - currentframe = inspect.currentframe() - assert currentframe is not None - f = currentframe.f_back - assert f is not None - if f.f_code.co_name != "": - raise NotImplementedError("wrap must be called at the top level of a module") - - # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search - # semantics would be slightly different, but would add support `from x import wrapped_function` - _wrapped_fns_to_patch.append((f.f_globals, fn_name)) - return fn_or_name - - -@compatibility(is_backward_compatible=True) -def symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, -) -> GraphModule: - """ - Symbolic tracing API - - Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` - constructed by recording operations seen while tracing through ``root``. - - ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. - - For example:: - - def f(a, b): - if b == True: - return a - else: - return a*2 - - FX can typically not trace through this due to the presence of control - flow. However, we can use `concrete_args` to specialize on the value of - `b` to trace through this:: - - f = fx.symbolic_trace(f, concrete_args={'b': False}) - assert f(3, False) == 6 - - Note that although you can still pass in different values of `b`, they will be ignored. - - We can also use `concrete_args` to eliminate data-structure handling from - our function. This will use pytrees to flatten your input. To avoid - overspecializing, pass in `fx.PH` for values that shouldn't be - specialized. For example:: - - def f(x): - out = 0 - for v in x.values(): - out += v - return out - f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) - assert f({'a': 1, 'b': 2, 'c': 4}) == 7 - - - Args: - root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted - into a Graph representation. - concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized - - Returns: - GraphModule: a Module created from the recorded operations from ``root``. 
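# The brevitas handling in the converter diffs that follow boils down to reading
# three fields off a QuantTensor and turning its power-of-two scale into
# fixed-point bits. A self-contained sketch, assuming a per-tensor fixed-point
# weight quantizer so that the scale really is a power of two:
import math

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFixedPoint

layer = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint)
qw = layer.quant_weight()            # a QuantTensor
bitwidth = int(qw.bit_width)         # e.g. 8
scale = float(qw.scale)              # 2**-F for the *FixedPoint quantizers
fract_bits = -math.log2(scale)
int_bits = bitwidth - fract_bits     # integer bits, sign bit included
weights = qw.detach().value.numpy()  # already-quantized values handed to hls4ml
print(f"ap_fixed<{bitwidth},{int(int_bits)}>")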
- """ - tracer = Tracer() - graph = tracer.trace(root, concrete_args) - name = ( - root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - ) - return GraphModule(tracer.root, graph, name) - - -@wrap -def _assert_is_none(value, msg): - assert value is None, msg diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index d10aca4843..55b6f59272 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,6 +1,6 @@ from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format -from hls4ml.model.types import FixedPrecisionType +from hls4ml.model.types import FixedPrecisionType, BrevitasQuantizer import math def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point): @@ -75,11 +75,22 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['data_format'] = 'channels_first' # Pytorch default (can't change) if "Quant" in operation: - layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['bias_data'] = class_object.quant_bias().detach().value.numpy() - width = class_object.quant_weight().bit_width - ap_fixed_params = ConvUAQToAp_Fixed(width, class_object.quant_weight().scale,0) - layer['precision'] = FixedPrecisionType(width=width, integer=ap_fixed_params[1], signed=True) + + if class_object.is_weight_quant_enabled: + width = int(class_object.quant_weight().bit_width) + ap_fixed_params = ConvUAQToAp_Fixed(width, float(class_object.quant_weight().scale),0) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + + if class_object.is_bias_quant_enabled: + width = int(class_object.quant_bias().bit_width) + ap_fixed_params = ConvUAQToAp_Fixed(width, float(class_object.quant_bias().scale),0) + layer['bias_data'] = class_object.quant_bias().detach().value.numpy() + layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') else: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index c66de48bf7..c6c4150610 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -34,7 +34,6 @@ def get_weights_data(self, layer_name, var_name): # if a layer is reused in the model, torch.FX will append a "_n" for the n-th use # have to snap that off to find the tensors - print (self.state_dict) if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1: layer_name = '_'.join(layer_name.split('_')[:-1]) @@ -171,7 +170,6 @@ def pytorch_to_hls(config): layer_counter = 0 n_inputs = 0 - print (traced_model) for node in traced_model.nodes: # If part of a nn.Sequntial, the node name will start with an "_" which messes up the parsing if node.name[0] == '_': diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index b6f2e42a01..e97d3df7cd 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -163,7 +163,21 @@ def __call__(self, data): y = y.numpy() return y +class 
BrevitasQuantizer(Quantizer): + """Wrapper around brevitas quantizers. Since we can get the already quantized tensors + directly from the brevitas QuantTensor objects, nothing needs to be done + Args: + bits: bitwidth of the quantized tensor + hls_type: hls_type of the quantized tensor + """ + + def __init__(self, bits, hls_type): + super().__init__(bits, hls_type) + + def __call__(self, data): + return data + # endregion # region Precision types From 52bb225f10d33df3e8f5c57bc8031dfa65f68058 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 9 Feb 2024 10:50:38 -0500 Subject: [PATCH 03/47] latest brevitas developments --- hls4ml/converters/pytorch/convolution.py | 47 +++++++++++------------- hls4ml/converters/pytorch/core.py | 33 +++++++++++++++-- hls4ml/converters/pytorch_to_hls.py | 16 ++++++++ test_brevitas.py | 20 ++++++++-- 4 files changed, 84 insertions(+), 32 deletions(-) diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index 55b6f59272..f4c2692b7b 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,23 +1,6 @@ -from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler +from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler, convert_uaq_to_apfixed from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format from hls4ml.model.types import FixedPrecisionType, BrevitasQuantizer -import math - -def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point): - """ - parameters: - bitwidth: int - scale_factor: float - zero_point: float - - return: - int_bitwidth: int - fract_bitwidth: int - """ - fract_bitwidth = - math.log2(scale_factor) - int_bitwidth = bitwidth - fract_bitwidth - - return (fract_bitwidth, int_bitwidth) @pytorch_handler('Conv1d', 'QuantConv1d') def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): @@ -29,8 +12,25 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c layer['class_name'] = 'Conv1D' layer['data_format'] = 'channels_first' # Pytorch default (can't change) - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + if "Quant" in operation: + if class_object.is_weight_quant_enabled: + width = int(class_object.quant_weight().bit_width) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + + if class_object.is_bias_quant_enabled: + width = int(class_object.quant_bias().bit_width) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) + layer['bias_data'] = class_object.quant_bias().detach().value.numpy() + layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + else: + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') # Input info (layer['in_width'], 
layer['n_chan']) = parse_data_format( input_shapes[0], 'channels_first' @@ -75,10 +75,9 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['data_format'] = 'channels_first' # Pytorch default (can't change) if "Quant" in operation: - if class_object.is_weight_quant_enabled: width = int(class_object.quant_weight().bit_width) - ap_fixed_params = ConvUAQToAp_Fixed(width, float(class_object.quant_weight().scale),0) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) else: @@ -86,12 +85,11 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c if class_object.is_bias_quant_enabled: width = int(class_object.quant_bias().bit_width) - ap_fixed_params = ConvUAQToAp_Fixed(width, float(class_object.quant_bias().scale),0) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) else: layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') - else: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') @@ -109,7 +107,6 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['dilation'] = class_object.dilation[0] layer['pad_top'] = layer['pad_bottom'] = class_object.padding[0] layer['pad_left'] = layer['pad_right'] = class_object.padding[1] - if all(x == 0 for x in class_object.padding): # No padding, i.e., 'VALID' padding in Keras/Tensorflow layer['padding'] = 'valid' else: # Only 'valid' and 'same' padding are available in Keras diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 67a374c527..1e780daa49 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,4 +1,5 @@ -from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler +from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler, convert_uaq_to_apfixed +from hls4ml.model.types import FixedPrecisionType, BrevitasQuantizer @pytorch_handler('Linear', 'QuantLinear') @@ -10,7 +11,25 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c layer['class_name'] = 'Dense' layer['name'] = layer_name - layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias']) + if "Quant" in operation: + if class_object.is_weight_quant_enabled: + width = int(class_object.quant_weight().bit_width) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + + if class_object.is_bias_quant_enabled: + width = int(class_object.quant_bias().bit_width) + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) + layer['bias_data'] = 
class_object.quant_bias().detach().value.numpy() + layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + else: + layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + else: + layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias']) + if class_object is not None: layer['n_in'] = class_object.in_features layer['n_out'] = class_object.out_features @@ -40,8 +59,13 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['activation'] = layer['class_name'] layer['name'] = layer_name - # if layer['class_name'] != 'Activation': - # layer['activation'] = layer['class_name'] + if "Quant" in operation: + layer['class_name'] = operation.split('Quant')[-1] + layer['activation'] = layer['class_name'] + bit_width = class_object.quant_act_bit_width() + ap_fixed_params = convert_uaq_to_apfixed(bit_width,class_object.quant_act_scale()) + layer['activation_quantizer'] = BrevitasQuantizer(bit_width,FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=True)) + if node.op == 'call_module': if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid': layer['class_name'] = 'Activation' @@ -77,6 +101,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['axis'] = node.kwargs['dim'] output_shape = input_shapes[0] + print (layer) return layer, output_shape diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index c6c4150610..d406ea60f1 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -1,5 +1,6 @@ import numpy as np import torch +import math from hls4ml.model import ModelGraph @@ -78,6 +79,21 @@ def get_weights_data(data_reader, layer_name, var_name): else: return (*data,) +def convert_uaq_to_apfixed(bitwidth, scale_factor): + """ + parameters: + bitwidth: int + scale_factor: float + zero_point: float + + return: + int_bitwidth: int + fract_bitwidth: int + """ + fract_bitwidth = - math.log2(scale_factor) + int_bitwidth = bitwidth - fract_bitwidth + + return (fract_bitwidth, int_bitwidth) # ----------------------Layer handling--------------------- # layer_handlers = {} diff --git a/test_brevitas.py b/test_brevitas.py index 755d09bae2..2e6a8dc4a4 100644 --- a/test_brevitas.py +++ b/test_brevitas.py @@ -100,12 +100,26 @@ def forward(self, x): out = self.relu1(self.conv1(x)) return out +class LinearModel(Module): + def __init__(self): + super(LinearModel, self).__init__() + self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint) + #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + #self.relu1 = nn.ReLU() + self.relu1 = qnn.QuantReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + quant_weight_lenet = QuantWeightLeNet() -model = QuantModel() +#model = QuantModel() +model = LinearModel() -x = torch.randn(3,6,5) +#x = torch.randn(3,6,5) +x = torch.tensor([1.,2.,3.,4.]) quant_linear = qnn.QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint, bias=False) print(f"Weight QuantTensor Linear:\n {quant_linear.quant_weight()}") @@ -132,7 +146,7 @@ def forward(self, x): hls_model = convert_from_pytorch_model( model, - (None, 3,6,5), + (None, 4), hls_config=config, output_dir=output_dir, backend=backend, From b212f8e01e8ce8acab04657cd3edd0cdb122d5af Mon Sep 17 00:00:00 2001 From: Fotis Giasemis Date: Fri, 1 Mar 2024 
14:50:41 +0100 Subject: [PATCH 04/47] Avoid Y2K22 Xilinx bug --- hls4ml/templates/vivado/build_prj.tcl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index d34337c573..f7b757bc71 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -229,7 +229,7 @@ if {$opt(validation)} { if {$opt(export)} { puts "***** EXPORT IP *****" set time_start [clock clicks -milliseconds] - export_design -format ip_catalog + export_design -format ip_catalog -version "1.0.0" set time_end [clock clicks -milliseconds] report_time "EXPORT IP" $time_start $time_end } From bf37179ee5036265b41d8c76cf30f40783487865 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 20 May 2024 16:24:44 -0400 Subject: [PATCH 05/47] state of the art brevitas parsing and add pytest --- hls4ml/converters/pytorch/core.py | 2 +- test/pytest/test_brevitas_parsing.py | 99 ++++++++++++++++ test_brevitas.py | 5 +- test_brevitas_conv.py | 161 +++++++++++++++++++++++++++ 4 files changed, 264 insertions(+), 3 deletions(-) create mode 100644 test/pytest/test_brevitas_parsing.py create mode 100644 test_brevitas_conv.py diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 1e780daa49..0d4b4e0b5d 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -64,7 +64,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['activation'] = layer['class_name'] bit_width = class_object.quant_act_bit_width() ap_fixed_params = convert_uaq_to_apfixed(bit_width,class_object.quant_act_scale()) - layer['activation_quantizer'] = BrevitasQuantizer(bit_width,FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=True)) + layer['activation_quantizer'] = BrevitasQuantizer(bit_width,FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False)) if node.op == 'call_module': if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid': diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py new file mode 100644 index 0000000000..8dd423121b --- /dev/null +++ b/test/pytest/test_brevitas_parsing.py @@ -0,0 +1,99 @@ +import math +from pathlib import Path + +import torch +from torch import nn +from torch.nn import Module +import torch.nn.functional as F + +import brevitas.nn as qnn +from brevitas.quant import Int8WeightPerTensorFixedPoint + +import numpy as np +import pytest + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +test_root_path = Path(__file__).parent + +class QuantModelConv2d(Module): + def __init__(self): + super(QuantModelConv2d, self).__init__() + self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.relu1 = nn.ReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + +class QuantModelLinear(Module): + def __init__(self): + super(QuantModelLinear, self).__init__() + self.conv1 = qnn.QuantLinear(4, 4, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.relu1 = qnn.QuantReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_quantlinear(backend, io_type): + model = QuantModelLinear() + + x = torch.tensor([1.,2.,3.,4.]) + + 
pytorch_prediction = model(x).detach().numpy() + config = config_from_pytorch_model(model) + output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model( + model, + (None, 4), + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + hls_model.compile() + + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_quantconv2d(backend, io_type): + model = QuantModelConv2d() + + x = torch.randn(1,3,6,5) + + pytorch_prediction = model(x).detach().numpy() + config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) + if io_type == 'io_stream': + x = np.ascontiguousarray(x.transpose(0, 2, 3, 1)) + config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) + else: + config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) + + output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model( + model, + (None, 3,6,5), + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + hls_model.compile() + + if io_type == 'io_stream': + hls_prediction = np.transpose( + np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape), (0, 3, 1, 2) + ) + else: + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) diff --git a/test_brevitas.py b/test_brevitas.py index 2e6a8dc4a4..12186363c0 100644 --- a/test_brevitas.py +++ b/test_brevitas.py @@ -116,7 +116,7 @@ def forward(self, x): #model = QuantModel() model = LinearModel() - +#model.eval() #x = torch.randn(3,6,5) x = torch.tensor([1.,2.,3.,4.]) @@ -156,4 +156,5 @@ def forward(self, x): hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) print(pytorch_prediction) -print(hls_prediction) \ No newline at end of file +print(hls_prediction) +print(torch.__version__) diff --git a/test_brevitas_conv.py b/test_brevitas_conv.py new file mode 100644 index 0000000000..e2beed6514 --- /dev/null +++ b/test_brevitas_conv.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import Module +import torch.nn.functional as F + +import brevitas.nn as qnn +from brevitas.quant import Int8WeightPerTensorFixedPoint + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +import math + + + + +""" +conversion function from target ap_fixed to parameters for UAQ +""" +from typing import Tuple + +def ConvAp_FixedToUAQ(int_bitwidth, fract_bitwidth) -> Tuple[int,float,float]: + """ + parameters: + int_bitwidth: int + fract_bitwidth: int + + return: + bitwidth: int + scale_factor: float + zero_point: float + """ + bitwidth = int_bitwidth + fract_bitwidth + scale_factor = 2**(-fract_bitwidth) + zero_point = 0 # we assume int representation is signed + + return (bitwidth, scale_factor, zero_point) + + + +""" +conversion function from UAQ to ap_fixed +""" +from typing import Tuple + +def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point) -> Tuple[int,int]: + """ 
+ parameters: + bitwidth: int + scale_factor: float + zero_point: float + + return: + int_bitwidth: int + fract_bitwidth: int + """ + fract_bitwidth = - math.log2(scale_factor) + int_bitwidth = bitwidth - fract_bitwidth + + return (bitwidth, int_bitwidth) + + + +class QuantWeightLeNet(Module): + def __init__(self): + super(QuantWeightLeNet, self).__init__() + self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + self.relu1 = nn.ReLU() + self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4) + self.relu2 = nn.ReLU() + self.fc1 = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4) + self.relu3 = nn.ReLU() + self.fc2 = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4) + self.relu4 = nn.ReLU() + self.fc3 = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4) + + def forward(self, x): + out = self.relu1(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = self.relu2(self.conv2(out)) + out = F.max_pool2d(out, 2) + #out = out.reshape(out.reshape[0], -1) + out = self.relu3(self.fc1(out)) + out = self.relu4(self.fc2(out)) + out = self.fc3(out) + return out + +quant_weight_lenet = QuantWeightLeNet() + + +class QuantModel(Module): + def __init__(self): + super(QuantModel, self).__init__() + #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.conv1 = nn.Conv2d(3, 6, 5, bias=True) + #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + self.relu1 = nn.ReLU() + + def forward(self, x): + #out = self.relu1(self.conv1(x)) + out = self.conv1(x) + return out + +class LinearModel(Module): + def __init__(self): + super(LinearModel, self).__init__() + self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint) + #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) + #self.relu1 = nn.ReLU() + self.relu1 = qnn.QuantReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + +quant_weight_lenet = QuantWeightLeNet() + +model = QuantModel() +#model = LinearModel() + + +x = torch.randn(3,6,5) +#x = torch.tensor([1.,2.,3.,4.]) + +quant_linear = qnn.QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint, bias=False) +print(f"Weight QuantTensor Linear:\n {quant_linear.quant_weight()}") +print(f"Quant Weight fix point: {- math.log2(quant_linear.quant_weight().scale)}") +print(f"Quant Weight scale: {quant_linear.quant_weight().scale}") +print(f"Quant Weight bit width: {quant_linear.quant_weight().bit_width}") +print(f"Quant Weight zero point: {quant_linear.quant_weight().zero_point}") + +pytorch_prediction = model(x).detach().numpy() +# print(f"Weight Tensor:\n {model.conv1.weight}") +# print(f"Weight QuantTensor:\n {model.conv1.quant_weight()}") +# print(f"Quant Weight fix point: {- math.log2(model.conv1.quant_weight().scale)}") +# print(f"Quant Weight scale: {model.conv1.quant_weight().scale}") +# print(f"Quant Weight bit width: {model.conv1.quant_weight().bit_width}") +# print(f"Quant Weight zero point: {model.conv1.quant_weight().zero_point}") +# ap_fixed_params = ConvUAQToAp_Fixed(8, model.conv1.quant_weight().scale,0) +# print (ap_fixed_params) +config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) +#config['Model']['Precision'] = 'ap_fixed<%d,%d>'%(ap_fixed_params[0],ap_fixed_params[1]) +print (config) +output_dir = "test_pytorch" +backend = "Vivado" +io_type = 'io_parallel' + +hls_model = convert_from_pytorch_model( + model, + (None, 3,6,5), + hls_config=config, + output_dir=output_dir, + 
backend=backend, + io_type=io_type, +) +hls_model.compile() + +hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) +print(pytorch_prediction) +print(hls_prediction) \ No newline at end of file From 5316a482c62f30fce004a8cf27e25dc1b6fc7719 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Tue, 21 May 2024 11:22:43 -0400 Subject: [PATCH 06/47] fix some compilation errors --- .../templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h | 1 + .../templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h | 1 + 2 files changed, 2 insertions(+) diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h index a110d6d424..974bb95807 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_resource.h @@ -3,6 +3,7 @@ #include "nnet_common.h" #include "nnet_dense.h" +#include namespace nnet { diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h index 73ad45592f..c1c14d53b0 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv2d_resource.h @@ -4,6 +4,7 @@ #include "nnet_common.h" #include "nnet_dense.h" #include "nnet_helpers.h" +#include namespace nnet { From 6e47e9a12b9166b195ccb9223cdae30f7fd3fcbd Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Tue, 21 May 2024 11:41:27 -0400 Subject: [PATCH 07/47] fix another trivial error in pytests --- test/pytest/test_brevitas_parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 8dd423121b..5b29ed414d 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -72,7 +72,7 @@ def test_quantconv2d(backend, io_type): pytorch_prediction = model(x).detach().numpy() config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) if io_type == 'io_stream': - x = np.ascontiguousarray(x.transpose(0, 2, 3, 1)) + x = np.ascontiguousarray(x.permute(0, 2, 3, 1)) config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) else: config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) @@ -91,7 +91,7 @@ def test_quantconv2d(backend, io_type): if io_type == 'io_stream': hls_prediction = np.transpose( - np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape), (0, 3, 1, 2) + np.reshape(hls_model.predict(x), pytorch_prediction.shape), (0, 3, 1, 2) ) else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) From f2201b00ad20906d9408be81e7b25968c344c84b Mon Sep 17 00:00:00 2001 From: simon71701 <54042406+simon71701@users.noreply.github.com> Date: Wed, 22 May 2024 14:43:00 -0400 Subject: [PATCH 08/47] Delete test_brevitas.py remove testing script --- test_brevitas.py | 160 ----------------------------------------------- 1 file changed, 160 deletions(-) delete mode 100644 test_brevitas.py diff --git a/test_brevitas.py b/test_brevitas.py deleted file mode 100644 index 12186363c0..0000000000 --- a/test_brevitas.py +++ /dev/null @@ -1,160 +0,0 @@ -import numpy as np -import torch -from torch import nn -from torch.nn import Module -import torch.nn.functional as F - 
-import brevitas.nn as qnn -from brevitas.quant import Int8WeightPerTensorFixedPoint - -from hls4ml.converters import convert_from_pytorch_model -from hls4ml.utils.config import config_from_pytorch_model - -import math - - - - -""" -conversion function from target ap_fixed to parameters for UAQ -""" -from typing import Tuple - -def ConvAp_FixedToUAQ(int_bitwidth, fract_bitwidth) -> Tuple[int,float,float]: - """ - parameters: - int_bitwidth: int - fract_bitwidth: int - - return: - bitwidth: int - scale_factor: float - zero_point: float - """ - bitwidth = int_bitwidth + fract_bitwidth - scale_factor = 2**(-fract_bitwidth) - zero_point = 0 # we assume int representation is signed - - return (bitwidth, scale_factor, zero_point) - - - -""" -conversion function from UAQ to ap_fixed -""" -from typing import Tuple - -def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point) -> Tuple[int,int]: - """ - parameters: - bitwidth: int - scale_factor: float - zero_point: float - - return: - int_bitwidth: int - fract_bitwidth: int - """ - fract_bitwidth = - math.log2(scale_factor) - int_bitwidth = bitwidth - fract_bitwidth - - return (bitwidth, int_bitwidth) - - - -class QuantWeightLeNet(Module): - def __init__(self): - super(QuantWeightLeNet, self).__init__() - self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - self.relu1 = nn.ReLU() - self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4) - self.relu2 = nn.ReLU() - self.fc1 = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4) - self.relu3 = nn.ReLU() - self.fc2 = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4) - self.relu4 = nn.ReLU() - self.fc3 = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4) - - def forward(self, x): - out = self.relu1(self.conv1(x)) - out = F.max_pool2d(out, 2) - out = self.relu2(self.conv2(out)) - out = F.max_pool2d(out, 2) - #out = out.reshape(out.reshape[0], -1) - out = self.relu3(self.fc1(out)) - out = self.relu4(self.fc2(out)) - out = self.fc3(out) - return out - -quant_weight_lenet = QuantWeightLeNet() - - -class QuantModel(Module): - def __init__(self): - super(QuantModel, self).__init__() - self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) - #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - self.relu1 = nn.ReLU() - - def forward(self, x): - out = self.relu1(self.conv1(x)) - return out - -class LinearModel(Module): - def __init__(self): - super(LinearModel, self).__init__() - self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint) - #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - #self.relu1 = nn.ReLU() - self.relu1 = qnn.QuantReLU() - - def forward(self, x): - out = self.relu1(self.conv1(x)) - return out - -quant_weight_lenet = QuantWeightLeNet() - -#model = QuantModel() -model = LinearModel() -#model.eval() - -#x = torch.randn(3,6,5) -x = torch.tensor([1.,2.,3.,4.]) - -quant_linear = qnn.QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint, bias=False) -print(f"Weight QuantTensor Linear:\n {quant_linear.quant_weight()}") -print(f"Quant Weight fix point: {- math.log2(quant_linear.quant_weight().scale)}") -print(f"Quant Weight scale: {quant_linear.quant_weight().scale}") -print(f"Quant Weight bit width: {quant_linear.quant_weight().bit_width}") -print(f"Quant Weight zero point: {quant_linear.quant_weight().zero_point}") - -pytorch_prediction = model(x).detach().numpy() -print(f"Weight Tensor:\n {model.conv1.weight}") -print(f"Weight 
QuantTensor:\n {model.conv1.quant_weight()}") -print(f"Quant Weight fix point: {- math.log2(model.conv1.quant_weight().scale)}") -print(f"Quant Weight scale: {model.conv1.quant_weight().scale}") -print(f"Quant Weight bit width: {model.conv1.quant_weight().bit_width}") -print(f"Quant Weight zero point: {model.conv1.quant_weight().zero_point}") -ap_fixed_params = ConvUAQToAp_Fixed(8, model.conv1.quant_weight().scale,0) -print (ap_fixed_params) -config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) -#config['Model']['Precision'] = 'ap_fixed<%d,%d>'%(ap_fixed_params[0],ap_fixed_params[1]) -print (config) -output_dir = "test_pytorch" -backend = "Vivado" -io_type = 'io_parallel' - -hls_model = convert_from_pytorch_model( - model, - (None, 4), - hls_config=config, - output_dir=output_dir, - backend=backend, - io_type=io_type, -) -hls_model.compile() - -hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) -print(pytorch_prediction) -print(hls_prediction) -print(torch.__version__) From f64fab6e97f552bef5801626d18add0a4de5db80 Mon Sep 17 00:00:00 2001 From: simon71701 <54042406+simon71701@users.noreply.github.com> Date: Wed, 22 May 2024 14:43:16 -0400 Subject: [PATCH 09/47] Delete test_brevitas_conv.py remove testing script --- test_brevitas_conv.py | 161 ------------------------------------------ 1 file changed, 161 deletions(-) delete mode 100644 test_brevitas_conv.py diff --git a/test_brevitas_conv.py b/test_brevitas_conv.py deleted file mode 100644 index e2beed6514..0000000000 --- a/test_brevitas_conv.py +++ /dev/null @@ -1,161 +0,0 @@ -import numpy as np -import torch -from torch import nn -from torch.nn import Module -import torch.nn.functional as F - -import brevitas.nn as qnn -from brevitas.quant import Int8WeightPerTensorFixedPoint - -from hls4ml.converters import convert_from_pytorch_model -from hls4ml.utils.config import config_from_pytorch_model - -import math - - - - -""" -conversion function from target ap_fixed to parameters for UAQ -""" -from typing import Tuple - -def ConvAp_FixedToUAQ(int_bitwidth, fract_bitwidth) -> Tuple[int,float,float]: - """ - parameters: - int_bitwidth: int - fract_bitwidth: int - - return: - bitwidth: int - scale_factor: float - zero_point: float - """ - bitwidth = int_bitwidth + fract_bitwidth - scale_factor = 2**(-fract_bitwidth) - zero_point = 0 # we assume int representation is signed - - return (bitwidth, scale_factor, zero_point) - - - -""" -conversion function from UAQ to ap_fixed -""" -from typing import Tuple - -def ConvUAQToAp_Fixed(bitwidth, scale_factor, zero_point) -> Tuple[int,int]: - """ - parameters: - bitwidth: int - scale_factor: float - zero_point: float - - return: - int_bitwidth: int - fract_bitwidth: int - """ - fract_bitwidth = - math.log2(scale_factor) - int_bitwidth = bitwidth - fract_bitwidth - - return (bitwidth, int_bitwidth) - - - -class QuantWeightLeNet(Module): - def __init__(self): - super(QuantWeightLeNet, self).__init__() - self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - self.relu1 = nn.ReLU() - self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4) - self.relu2 = nn.ReLU() - self.fc1 = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4) - self.relu3 = nn.ReLU() - self.fc2 = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4) - self.relu4 = nn.ReLU() - self.fc3 = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4) - - def forward(self, x): - out = self.relu1(self.conv1(x)) - out = 
F.max_pool2d(out, 2) - out = self.relu2(self.conv2(out)) - out = F.max_pool2d(out, 2) - #out = out.reshape(out.reshape[0], -1) - out = self.relu3(self.fc1(out)) - out = self.relu4(self.fc2(out)) - out = self.fc3(out) - return out - -quant_weight_lenet = QuantWeightLeNet() - - -class QuantModel(Module): - def __init__(self): - super(QuantModel, self).__init__() - #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) - self.conv1 = nn.Conv2d(3, 6, 5, bias=True) - #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - self.relu1 = nn.ReLU() - - def forward(self, x): - #out = self.relu1(self.conv1(x)) - out = self.conv1(x) - return out - -class LinearModel(Module): - def __init__(self): - super(LinearModel, self).__init__() - self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint) - #self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4) - #self.relu1 = nn.ReLU() - self.relu1 = qnn.QuantReLU() - - def forward(self, x): - out = self.relu1(self.conv1(x)) - return out - -quant_weight_lenet = QuantWeightLeNet() - -model = QuantModel() -#model = LinearModel() - - -x = torch.randn(3,6,5) -#x = torch.tensor([1.,2.,3.,4.]) - -quant_linear = qnn.QuantLinear(2, 4, weight_quant=Int8WeightPerTensorFixedPoint, bias=False) -print(f"Weight QuantTensor Linear:\n {quant_linear.quant_weight()}") -print(f"Quant Weight fix point: {- math.log2(quant_linear.quant_weight().scale)}") -print(f"Quant Weight scale: {quant_linear.quant_weight().scale}") -print(f"Quant Weight bit width: {quant_linear.quant_weight().bit_width}") -print(f"Quant Weight zero point: {quant_linear.quant_weight().zero_point}") - -pytorch_prediction = model(x).detach().numpy() -# print(f"Weight Tensor:\n {model.conv1.weight}") -# print(f"Weight QuantTensor:\n {model.conv1.quant_weight()}") -# print(f"Quant Weight fix point: {- math.log2(model.conv1.quant_weight().scale)}") -# print(f"Quant Weight scale: {model.conv1.quant_weight().scale}") -# print(f"Quant Weight bit width: {model.conv1.quant_weight().bit_width}") -# print(f"Quant Weight zero point: {model.conv1.quant_weight().zero_point}") -# ap_fixed_params = ConvUAQToAp_Fixed(8, model.conv1.quant_weight().scale,0) -# print (ap_fixed_params) -config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) -#config['Model']['Precision'] = 'ap_fixed<%d,%d>'%(ap_fixed_params[0],ap_fixed_params[1]) -print (config) -output_dir = "test_pytorch" -backend = "Vivado" -io_type = 'io_parallel' - -hls_model = convert_from_pytorch_model( - model, - (None, 3,6,5), - hls_config=config, - output_dir=output_dir, - backend=backend, - io_type=io_type, -) -hls_model.compile() - -hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) -print(pytorch_prediction) -print(hls_prediction) \ No newline at end of file From 45954f9b126a678efd698b0e04ef266feb8b8d6f Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 7 Jun 2024 14:20:19 -0400 Subject: [PATCH 10/47] fix dimensions in Conv2D pytest for brevitas parsing --- test/pytest/test_brevitas_parsing.py | 46 +++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 5b29ed414d..072581a66d 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -67,6 +67,12 @@ def test_quantlinear(backend, io_type): def test_quantconv2d(backend, io_type): 
model = QuantModelConv2d() + n_in = 3 + n_out = 6 + kernel_size = 5 + size_in_width = 5 + size_in_height = 6 + x = torch.randn(1,3,6,5) pytorch_prediction = model(x).detach().numpy() @@ -79,6 +85,44 @@ def test_quantconv2d(backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + from hls4ml.converters.pytorch_to_hls import CustomFXTracer + + tracer = CustomFXTracer() + traced_model = tracer.trace(model) + + nNodes = 0 + convNode = None + for _node in traced_model.nodes: + nNodes += 1 + if nNodes == 2: + convNode = _node + + children = {c[0]: c[1] for c in model.named_children()} + class_object_conv = children[convNode.target] + + out_width = int( + ( + size_in_width + + 2 * class_object_conv.padding[1] + - class_object_conv.dilation[1] * (class_object_conv.kernel_size[1] - 1) + - 1 + ) + / class_object_conv.stride[1] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + out_height = int( + ( + size_in_height + + 2 * class_object_conv.padding[0] + - class_object_conv.dilation[0] * (class_object_conv.kernel_size[0] - 1) + - 1 + ) + / class_object_conv.stride[0] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + + hls_model = convert_from_pytorch_model( model, (None, 3,6,5), @@ -91,7 +135,7 @@ def test_quantconv2d(backend, io_type): if io_type == 'io_stream': hls_prediction = np.transpose( - np.reshape(hls_model.predict(x), pytorch_prediction.shape), (0, 3, 1, 2) + np.reshape(hls_model.predict(x), (1,out_height, out_width, n_out)), (0, 3, 1, 2) ) else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) From 73af4c13fb2bfa1ff81f2d35c9c4809dda5e340d Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 7 Jun 2024 14:28:49 -0400 Subject: [PATCH 11/47] trigger pre-commit --- test/pytest/test_brevitas_parsing.py | 40 ++++++++++++---------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 072581a66d..c1aeb14bf5 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -1,25 +1,22 @@ -import math from pathlib import Path -import torch -from torch import nn -from torch.nn import Module -import torch.nn.functional as F - import brevitas.nn as qnn -from brevitas.quant import Int8WeightPerTensorFixedPoint - import numpy as np import pytest +import torch +from brevitas.quant import Int8WeightPerTensorFixedPoint +from torch import nn +from torch.nn import Module from hls4ml.converters import convert_from_pytorch_model from hls4ml.utils.config import config_from_pytorch_model test_root_path = Path(__file__).parent + class QuantModelConv2d(Module): def __init__(self): - super(QuantModelConv2d, self).__init__() + super().__init__() self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) self.relu1 = nn.ReLU() @@ -27,9 +24,10 @@ def forward(self, x): out = self.relu1(self.conv1(x)) return out + class QuantModelLinear(Module): def __init__(self): - super(QuantModelLinear, self).__init__() + super().__init__() self.conv1 = qnn.QuantLinear(4, 4, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) self.relu1 = qnn.QuantReLU() @@ -37,12 +35,13 @@ def forward(self, x): out = self.relu1(self.conv1(x)) return out + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_quantlinear(backend, 
io_type): model = QuantModelLinear() - x = torch.tensor([1.,2.,3.,4.]) + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) pytorch_prediction = model(x).detach().numpy() config = config_from_pytorch_model(model) @@ -59,24 +58,23 @@ def test_quantlinear(backend, io_type): hls_model.compile() hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_quantconv2d(backend, io_type): model = QuantModelConv2d() - n_in = 3 n_out = 6 - kernel_size = 5 size_in_width = 5 size_in_height = 6 - x = torch.randn(1,3,6,5) + x = torch.randn(1, 3, 6, 5) pytorch_prediction = model(x).detach().numpy() - config = config_from_pytorch_model(model, inputs_channel_last=False,transpose_outputs=True) + config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) if io_type == 'io_stream': x = np.ascontiguousarray(x.permute(0, 2, 3, 1)) config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) @@ -121,11 +119,9 @@ def test_quantconv2d(backend, io_type): + 1 ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - - hls_model = convert_from_pytorch_model( model, - (None, 3,6,5), + (None, 3, 6, 5), hls_config=config, output_dir=output_dir, backend=backend, @@ -134,10 +130,8 @@ def test_quantconv2d(backend, io_type): hls_model.compile() if io_type == 'io_stream': - hls_prediction = np.transpose( - np.reshape(hls_model.predict(x), (1,out_height, out_width, n_out)), (0, 3, 1, 2) - ) + hls_prediction = np.transpose(np.reshape(hls_model.predict(x), (1, out_height, out_width, n_out)), (0, 3, 1, 2)) else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) From 4f401ed1d7b9aacb3b2fb21b8d5632c64169a244 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 7 Jun 2024 14:36:58 -0400 Subject: [PATCH 12/47] move quantizer to new file --- hls4ml/converters/pytorch/convolution.py | 22 ++++++++++---- hls4ml/converters/pytorch/core.py | 37 ++++++++++++++++++------ hls4ml/model/quantizers.py | 16 ++++++++++ 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index 68d9b56205..205638cc88 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,6 +1,8 @@ -from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler, convert_uaq_to_apfixed +from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, get_weights_data, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format -from hls4ml.model.types import FixedPrecisionType, BrevitasQuantizer +from hls4ml.model.quantizers import BrevitasQuantizer +from hls4ml.model.types import FixedPrecisionType + @pytorch_handler('Conv1d', 'QuantConv1d') def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): @@ -18,7 +20,9 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, 
float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') @@ -26,7 +30,9 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() - layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') else: @@ -81,7 +87,9 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') @@ -89,7 +97,9 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() - layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') else: diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 85af4994bd..a75e6b4bf2 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,5 +1,6 @@ -from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler, convert_uaq_to_apfixed -from hls4ml.model.types import FixedPrecisionType, BrevitasQuantizer +from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, get_weights_data, pytorch_handler +from hls4ml.model.quantizers import BrevitasQuantizer +from hls4ml.model.types import FixedPrecisionType @pytorch_handler('Linear', 'QuantLinear') @@ -23,7 +24,9 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) 
+ layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') @@ -31,7 +34,9 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c width = int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() - layer['bias_quantizer'] = BrevitasQuantizer(width,FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)) + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) else: layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') else: @@ -55,7 +60,19 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c return layer, output_shape -activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 'Sigmoid', 'Tanh','QuantReLU','QuantSigmoid','QuantTanh'] +activation_layers = [ + 'Softmax', + 'ReLU', + 'LeakyReLU', + 'Threshold', + 'ELU', + 'PReLU', + 'Sigmoid', + 'Tanh', + 'QuantReLU', + 'QuantSigmoid', + 'QuantTanh', +] @pytorch_handler(*activation_layers) @@ -71,9 +88,11 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['class_name'] = operation.split('Quant')[-1] layer['activation'] = layer['class_name'] bit_width = class_object.quant_act_bit_width() - ap_fixed_params = convert_uaq_to_apfixed(bit_width,class_object.quant_act_scale()) - layer['activation_quantizer'] = BrevitasQuantizer(bit_width,FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False)) - + ap_fixed_params = convert_uaq_to_apfixed(bit_width, class_object.quant_act_scale()) + layer['activation_quantizer'] = BrevitasQuantizer( + bit_width, FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False) + ) + if node.op == 'call_module': if layer['class_name'] == 'ReLU' or layer['class_name'] == 'Sigmoid': layer['class_name'] = 'Activation' @@ -109,7 +128,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['axis'] = node.kwargs['dim'] output_shape = input_shapes[0] - print (layer) + print(layer) return layer, output_shape diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index c857ef51ac..ffa363f4ad 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -158,3 +158,19 @@ def __call__(self, data): if hasattr(y, 'numpy'): y = y.numpy() return y + + +class BrevitasQuantizer(Quantizer): + """Wrapper around brevitas quantizers. 
Since we can get the already quantized tensors + directly from the brevitas QuantTensor objects, nothing needs to be done + + Args: + bits: bitwidth of the quantized tensor + hls_type: hls_type of the quantized tensor + """ + + def __init__(self, bits, hls_type): + super().__init__(bits, hls_type) + + def __call__(self, data): + return data From 9c1740f30eb9719612aacd8504b848f6374dca44 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 7 Jun 2024 14:50:51 -0400 Subject: [PATCH 13/47] reduce diff and update access to tensors to latest version --- hls4ml/converters/pytorch/convolution.py | 25 +++++++++++++++--------- hls4ml/converters/pytorch/core.py | 13 +++++++----- hls4ml/model/types.py | 1 - 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index 205638cc88..dca523d345 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,4 +1,4 @@ -from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, get_weights_data, pytorch_handler +from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format from hls4ml.model.quantizers import BrevitasQuantizer from hls4ml.model.types import FixedPrecisionType @@ -24,7 +24,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + layer['weight_data'] = class_object.weight.data.numpy() if class_object.is_bias_quant_enabled: width = int(class_object.quant_bias().bit_width) @@ -34,10 +34,13 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + layer['bias_data'] = class_object.bias.data.numpy() else: - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + layer['weight_data'] = class_object.weight.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None # Input info (layer['in_width'], layer['n_chan']) = parse_data_format( input_shapes[0], 'channels_first' @@ -91,7 +94,7 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + layer['weight_data'] = class_object.weight.data.numpy() if class_object.is_bias_quant_enabled: width = int(class_object.quant_bias().bit_width) @@ -101,10 +104,14 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + layer['bias_data'] = class_object.bias.data.numpy() else: - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + layer['weight_data'] = 
class_object.weight.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None + # Input info (layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format( input_shapes[0], 'channels_first' diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index a75e6b4bf2..ab2c95116a 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,4 +1,4 @@ -from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, get_weights_data, pytorch_handler +from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler from hls4ml.model.quantizers import BrevitasQuantizer from hls4ml.model.types import FixedPrecisionType @@ -28,7 +28,7 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight') + layer['weight_data'] = class_object.weight.data.numpy() if class_object.is_bias_quant_enabled: width = int(class_object.quant_bias().bit_width) @@ -38,9 +38,13 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + layer['bias_data'] = class_object.bias.data.numpy() else: - layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias']) + layer['weight_data'] = class_object.weight.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None if class_object is not None: layer['n_in'] = class_object.in_features @@ -128,7 +132,6 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['axis'] = node.kwargs['dim'] output_shape = input_shapes[0] - print(layer) return layer, output_shape diff --git a/hls4ml/model/types.py b/hls4ml/model/types.py index 8da13ab4c1..fb5cde3863 100644 --- a/hls4ml/model/types.py +++ b/hls4ml/model/types.py @@ -9,7 +9,6 @@ import numpy as np - # region Precision types From c769fefcb8bab028a2580f5011dfd8d3ec9b4a2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:54:30 +0000 Subject: [PATCH 14/47] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/converters/pytorch/pooling.py | 11 ++++++-- hls4ml/converters/pytorch_to_hls.py | 42 ++++++++++++++++------------ 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index fecbbb0d19..c9cbea8ecf 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -1,7 +1,14 @@ from hls4ml.converters.pytorch_to_hls import pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format -pooling_layers = ['MaxPool1d', 'MaxPool2d', 'AvgPool1d', 'AvgPool2d', 'QuantMaxPool1d', 'QuantMaxPool2d'] #TODO add support for special quantized average pool layers +pooling_layers = [ + 'MaxPool1d', + 'MaxPool2d', + 'AvgPool1d', + 'AvgPool2d', + 'QuantMaxPool1d', + 'QuantMaxPool2d', +] # TODO add support for special quantized average pool layers @pytorch_handler(*pooling_layers) @@ 
-92,7 +99,7 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, layer['stride_width'] = node.args[-1][0] else: layer['stride_height'] = node.args[-1] - layer['stride_width'] = node.args[-1] + layer['stride_width'] = node.args[-1] elif type(node.kwargs['stride']) is tuple: layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1] diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 87ae6149bb..c235eb771e 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -1,8 +1,10 @@ -import torch import math +import torch + from hls4ml.model import ModelGraph + class CustomFXTracer(torch.fx.Tracer): def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: @@ -10,9 +12,11 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX """ return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn") or m.__module__.startswith("brevitas.nn")) - and not isinstance(m, torch.nn.Sequential) - ) + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + or m.__module__.startswith("brevitas.nn") + ) and not isinstance(m, torch.nn.Sequential) + class PyTorchModelReader: """ @@ -68,21 +72,23 @@ def get_weights_data(data_reader, layer_name, var_name): else: return (*data,) + def convert_uaq_to_apfixed(bitwidth, scale_factor): - """ - parameters: - bitwidth: int - scale_factor: float - zero_point: float - - return: - int_bitwidth: int - fract_bitwidth: int - """ - fract_bitwidth = - math.log2(scale_factor) - int_bitwidth = bitwidth - fract_bitwidth - - return (fract_bitwidth, int_bitwidth) + """ + parameters: + bitwidth: int + scale_factor: float + zero_point: float + + return: + int_bitwidth: int + fract_bitwidth: int + """ + fract_bitwidth = -math.log2(scale_factor) + int_bitwidth = bitwidth - fract_bitwidth + + return (fract_bitwidth, int_bitwidth) + # ----------------------Layer handling--------------------- # layer_handlers = {} From 0bb09f068c7abd4b3f2b56be517944d7e90a9504 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Tue, 9 Jul 2024 13:15:42 -0400 Subject: [PATCH 15/47] add brevitas to the requirements for tests --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 1911f3b328..c34d32bb89 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,7 @@ profiling = sr = sympy testing = + brevitas pytest pytest-cov pytest-randomly From cda36b671795b4f5121cab145507fa37b51151fb Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 22 Jul 2024 16:33:14 -0400 Subject: [PATCH 16/47] adjust required precision in brevitas pytests --- test/pytest/test_brevitas_parsing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index c1aeb14bf5..b7b06e6296 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -59,7 +59,7 @@ def test_quantlinear(backend, io_type): hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) @pytest.mark.parametrize('backend', ['Vivado', 
'Quartus']) @@ -134,4 +134,4 @@ def test_quantconv2d(backend, io_type): else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) From ef380ad6a2044ed9317244a1500b629dda330fae Mon Sep 17 00:00:00 2001 From: Rian Flynn Date: Mon, 22 Jul 2024 16:31:02 -0400 Subject: [PATCH 17/47] Add conv1d tests, fix output dir and tolerances --- test/pytest/test_brevitas_parsing.py | 81 ++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index c1aeb14bf5..e1d3ff82b2 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -24,6 +24,16 @@ def forward(self, x): out = self.relu1(self.conv1(x)) return out +class QuantModelConv1d(Module): + def __init__(self): + super().__init__() + self.conv1 = qnn.QuantConv1d(3, 6, 4, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.relu1 = nn.ReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + class QuantModelLinear(Module): def __init__(self): @@ -61,17 +71,80 @@ def test_quantlinear(backend, io_type): np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_quantconv1d(backend, io_type): + model = QuantModelConv1d() + + n_in = 3 + n_out = 6 + size_in = 5 + + x = torch.randn(1, n_in, size_in) + + pytorch_prediction = model(x).detach().numpy() + if io_type == 'io_stream': + x = np.ascontiguousarray(x.permute(0, 2, 1)) + config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) + else: + config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) + + output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') + + from hls4ml.converters.pytorch_to_hls import CustomFXTracer + + tracer = CustomFXTracer() + traced_model = tracer.trace(model) + nNodes = 0 + convNode = None + for _node in traced_model.nodes: + nNodes += 1 + if nNodes == 2: + convNode = _node + + children = {c[0]: c[1] for c in model.named_children()} + class_object_conv = children[convNode.target] + + out_width = int( + ( + size_in + + 2 * class_object_conv.padding[0] + - class_object_conv.dilation[0] * (class_object_conv.kernel_size[0] - 1) + - 1 + ) + / class_object_conv.stride[0] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + + hls_model = convert_from_pytorch_model( + model, + (None, n_in, size_in), + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type + ) + hls_model.compile() + + if io_type == 'io_stream': + hls_prediction = np.transpose(np.reshape(hls_model.predict(x), (1, out_width, n_out)), (0, 2, 1)) + else: + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_quantconv2d(backend, io_type): model = QuantModelConv2d() + n_in = 3 n_out = 6 size_in_width = 5 size_in_height = 6 - x = torch.randn(1, 3, 6, 5) + x = 
torch.randn(1, n_in, size_in_height, size_in_width) pytorch_prediction = model(x).detach().numpy() config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) @@ -81,7 +154,7 @@ def test_quantconv2d(backend, io_type): else: config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) - output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}') from hls4ml.converters.pytorch_to_hls import CustomFXTracer @@ -121,7 +194,7 @@ def test_quantconv2d(backend, io_type): hls_model = convert_from_pytorch_model( model, - (None, 3, 6, 5), + (None, n_in, size_in_height, size_in_width), hls_config=config, output_dir=output_dir, backend=backend, @@ -134,4 +207,4 @@ def test_quantconv2d(backend, io_type): else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2) From dffa37997030ee76b09a2a02a2a9ddde828614d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 20:54:15 +0000 Subject: [PATCH 18/47] [pre-commit.ci] auto fixes from pre-commit hooks --- test/pytest/test_brevitas_parsing.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index e1d3ff82b2..597884f0f8 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -24,6 +24,7 @@ def forward(self, x): out = self.relu1(self.conv1(x)) return out + class QuantModelConv1d(Module): def __init__(self): super().__init__() @@ -71,6 +72,7 @@ def test_quantlinear(backend, io_type): np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + @pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_quantconv1d(backend, io_type): @@ -88,7 +90,7 @@ def test_quantconv1d(backend, io_type): config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) else: config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) - + output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') from hls4ml.converters.pytorch_to_hls import CustomFXTracer @@ -117,12 +119,7 @@ def test_quantconv1d(backend, io_type): ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html hls_model = convert_from_pytorch_model( - model, - (None, n_in, size_in), - hls_config=config, - output_dir=output_dir, - backend=backend, - io_type=io_type + model, (None, n_in, size_in), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type ) hls_model.compile() @@ -130,7 +127,7 @@ def test_quantconv1d(backend, io_type): hls_prediction = np.transpose(np.reshape(hls_model.predict(x), (1, out_width, n_out)), (0, 2, 1)) else: hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) From 399613e821ec030dafb26ec6ddd062a88a15962b Mon Sep 17 00:00:00 2001 From: Rian Flynn Date: Thu, 25 Jul 2024 16:50:22 -0400 Subject: [PATCH 19/47] Test QuantMaxPool and ignore QuantDropout --- 
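The pooling handler now keys on a substring of the operation name, so the brevitas pooling modules ('QuantMaxPool1d', 'QuantMaxPool2d') reuse the existing max-pooling parsing, and 'QuantDropout' joins 'Dropout' in the list of layers skipped during conversion; quantized average pooling remains a TODO. A minimal sketch of the dispatch this enables — the layer-class names mirror the handler below, and the loop itself is only for illustration:

# Both the torch and the brevitas pooling modules map onto the same hls4ml layer,
# which is why the handler checks "'MaxPool2d' in operation" instead of equality.
for operation in ('MaxPool1d', 'QuantMaxPool1d', 'MaxPool2d', 'QuantMaxPool2d'):
    if 'MaxPool1d' in operation:
        layer_class = 'MaxPooling1D'
    elif 'MaxPool2d' in operation:
        layer_class = 'MaxPooling2D'
    print(operation, '->', layer_class)
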
hls4ml/converters/pytorch/pooling.py | 4 +- hls4ml/converters/pytorch_to_hls.py | 2 +- test/pytest/test_brevitas_parsing.py | 59 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index c9cbea8ecf..82cefc5428 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -17,9 +17,9 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, layer = {} - if operation == 'MaxPool1d': + if 'MaxPool1d' in operation: layer['class_name'] = 'MaxPooling1D' - if operation == 'MaxPool2d': + if 'MaxPool2d' in operation: layer['class_name'] = 'MaxPooling2D' if operation == 'AvgPool1d': layer['class_name'] = 'AveragePooling1D' diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index a5b86d5282..682844cea8 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -168,7 +168,7 @@ def pytorch_to_hls(config): tracer = CustomFXTracer() traced_model = tracer.trace(model) # Define layers to skip for conversion to HLS - skip_layers = ['Dropout', 'Sequential'] + skip_layers = ['Dropout', 'QuantDropout', 'Sequential'] # All supported layers supported_layers = get_supported_pytorch_layers() + skip_layers diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 597884f0f8..4364e6db19 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -205,3 +205,62 @@ def test_quantconv2d(backend, io_type): hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2) + + +class QuantMaxPool1d(Module): + def __init__(self): + super().__init__() + self.pool = qnn.QuantMaxPool1d(2) + + def forward(self, x): + return self.pool(x) + + +class QuantMaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = qnn.QuantMaxPool2d(2) + + def forward(self, x): + return self.pool(x) + + +@pytest.mark.parametrize('pooling', [QuantMaxPool1d, QuantMaxPool2d]) +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +def test_pooling(pooling, backend): + model = pooling() + + assert '1d' in pooling.__name__ or '2d' in pooling.__name__ + + if '2d' in pooling.__name__: + n_in = 2 + size_in_height = 15 + size_in_width = 18 + else: + n_in = 2 + size_in_width = 121 + size_in_height = 0 + + input_shape = (1, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (1, n_in, size_in_width) + input_shape_forHLS = ( + (None, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (None, n_in, size_in_width) + ) + x = torch.randn(*input_shape) + + pytorch_prediction = model(x).tensor.detach().numpy() + + config = config_from_pytorch_model(model) + output_dir = str(test_root_path / f'hls4mlprj_brevitas_{pooling.__name__}_{backend}') + + hls_model = convert_from_pytorch_model( + model, + input_shape_forHLS, + hls_config=config, + output_dir=output_dir, + backend=backend, + ) + hls_model.compile() + + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2) From d13bf521ceeb8430a504bd7585aafc7e9ac0f582 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:35:41 +0000 Subject: [PATCH 
20/47] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/model/quantizers.py | 3 ++- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index cc38c96864..e7b32894aa 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -182,6 +182,7 @@ def __init__(self, bits, hls_type): def __call__(self, data): return data + class QuantNodeQuantizer(Quantizer): """ This implements a quantizer for a FixedPrecisionType with width==integer @@ -274,4 +275,4 @@ def _resolve_rounding_mode(mode): elif mode == RoundingMode.TRN: return np.floor else: - raise ValueError(f'Rounding mode {mode} not supported.') \ No newline at end of file + raise ValueError(f'Rounding mode {mode} not supported.') diff --git a/setup.cfg b/setup.cfg index 8498a67081..2989c378aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,8 +54,8 @@ profiling = sr = sympy testing = - brevitas HGQ~=0.2.0 + brevitas pytest pytest-cov pytest-randomly From 1f17845396266ff123d930a27a584b5dcae6b72e Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 10 Jan 2025 12:45:15 -0500 Subject: [PATCH 21/47] restore accidental change --- hls4ml/converters/pytorch/pooling.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index 957f84c8da..882a7f7aed 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -93,14 +93,7 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, padding = [class_object.padding, class_object.padding] else: - if node.kwargs['stride'] is None: - if type(node.args[-1]) is tuple: - layer['stride_height'] = node.args[-1][0] - layer['stride_width'] = node.args[-1][0] - else: - layer['stride_height'] = node.args[-1] - layer['stride_width'] = node.args[-1] - elif type(node.kwargs['stride']) is tuple: + if type(node.kwargs['stride']) is tuple: layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1] else: From 7e2fdf7c927f69319045e7aa6de7d6fc2d34697f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:25:25 +0000 Subject: [PATCH 22/47] [pre-commit.ci] auto fixes from pre-commit hooks --- hls4ml/converters/pytorch/core.py | 5 ++++- hls4ml/converters/pytorch_to_hls.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index c164a7574a..b5bd126181 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,8 +1,10 @@ import numpy as np + from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler from hls4ml.model.quantizers import BrevitasQuantizer from hls4ml.model.types import FixedPrecisionType + @pytorch_handler('Constant') def parse_constant_layer(operation, layer_name, node): assert 'Constant' in operation @@ -18,7 +20,8 @@ def parse_constant_layer(operation, layer_name, node): output_shape = constant.shape return layer, output_shape - + + @pytorch_handler('Linear', 'QuantLinear') def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Linear' in operation diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 80599eae53..b95c0b4a2b 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ 
b/hls4ml/converters/pytorch_to_hls.py @@ -1,4 +1,5 @@ import math + import numpy as np import torch @@ -190,7 +191,7 @@ def parse_pytorch_model(config, verbose=True): layer_counter = 0 n_inputs = 0 - + # check for constant nodes merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax'] i = 0 # count number of consts and use it in the name @@ -207,7 +208,7 @@ def parse_pytorch_model(config, verbose=True): i += 1 traced_model.graph.lint() - + for node in traced_model.nodes: if node.op == 'call_module': # modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x', From 10d77b62e242732bad37802c5cffa053bd1dd0b8 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 10 Jan 2025 14:24:50 -0500 Subject: [PATCH 23/47] update pytests for interface changes and fix merge errors --- hls4ml/converters/pytorch_to_hls.py | 8 +++----- test/pytest/test_brevitas_parsing.py | 28 +++++++++++++++------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index b95c0b4a2b..b0de5a4198 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -200,14 +200,12 @@ def parse_pytorch_model(config, verbose=True): for arg in node.args: if np.isscalar(arg): # add an input node with the constant value - new_node = traced_model.graph.placeholder( - name='const_' + str(i), type_expr=torch.Tensor, default_value=arg - ) + new_node = traced_model.placeholder(name='const_' + str(i), type_expr=torch.Tensor, default_value=arg) node.prepend(new_node) node.update_arg(1, new_node) i += 1 - traced_model.graph.lint() + traced_model.lint() for node in traced_model.nodes: if node.op == 'call_module': @@ -258,7 +256,7 @@ def parse_pytorch_model(config, verbose=True): input_shapes = [output_shapes[str(node.args[0])]] # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input elif "getitem" in node.args[0].name: - for tmp_node in traced_model.graph.nodes: + for tmp_node in traced_model.nodes: if tmp_node.name == node.args[0].name: if "getitem" in tmp_node.args[0].name: raise Exception('Nested getitem calles not resolved at the moment.') diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 9f2d83fbec..83ef370dc4 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -55,12 +55,11 @@ def test_quantlinear(backend, io_type): x = torch.tensor([1.0, 2.0, 3.0, 4.0]) pytorch_prediction = model(x).detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, input_shape=(None, 4)) output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') hls_model = convert_from_pytorch_model( model, - (None, 4), hls_config=config, output_dir=output_dir, backend=backend, @@ -87,9 +86,13 @@ def test_quantconv1d(backend, io_type): pytorch_prediction = model(x).detach().numpy() if io_type == 'io_stream': x = np.ascontiguousarray(x.permute(0, 2, 1)) - config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) + config = config_from_pytorch_model( + model, (None, n_in, size_in), channels_last_conversion="internal", transpose_outputs=False + ) else: - config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) + config = config_from_pytorch_model( + model, (None, n_in, size_in), channels_last_conversion="full", transpose_outputs=True + ) output_dir = 
str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') @@ -118,9 +121,7 @@ def test_quantconv1d(backend, io_type): + 1 ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - hls_model = convert_from_pytorch_model( - model, (None, n_in, size_in), hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type - ) + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) hls_model.compile() if io_type == 'io_stream': @@ -144,12 +145,15 @@ def test_quantconv2d(backend, io_type): x = torch.randn(1, n_in, size_in_height, size_in_width) pytorch_prediction = model(x).detach().numpy() - config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) if io_type == 'io_stream': x = np.ascontiguousarray(x.permute(0, 2, 3, 1)) - config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False) + config = config_from_pytorch_model( + model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="internal", transpose_outputs=False + ) else: - config = config_from_pytorch_model(model, inputs_channel_last=False, transpose_outputs=True) + config = config_from_pytorch_model( + model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="full", transpose_outputs=True + ) output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}') @@ -191,7 +195,6 @@ def test_quantconv2d(backend, io_type): hls_model = convert_from_pytorch_model( model, - (None, n_in, size_in_height, size_in_width), hls_config=config, output_dir=output_dir, backend=backend, @@ -249,12 +252,11 @@ def test_pooling(pooling, backend): pytorch_prediction = model(x).tensor.detach().numpy() - config = config_from_pytorch_model(model) + config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True) output_dir = str(test_root_path / f'hls4mlprj_brevitas_{pooling.__name__}_{backend}') hls_model = convert_from_pytorch_model( model, - input_shape_forHLS, hls_config=config, output_dir=output_dir, backend=backend, From 0cce3ef2292bf79d9471560d51b2e4cc25c34ac3 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 17 Jan 2025 12:24:55 -0500 Subject: [PATCH 24/47] update to work with brevitas 0.11.0 --- hls4ml/converters/pytorch/convolution.py | 8 ++-- hls4ml/converters/pytorch/core.py | 15 +++--- test/pytest/test_brevitas_parsing.py | 58 ------------------------ 3 files changed, 12 insertions(+), 69 deletions(-) diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index a02225fbab..22a4cae6f7 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -16,7 +16,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c layer['data_format'] = 'channels_first' # Pytorch default (can't change) if "Quant" in operation: - if class_object.is_weight_quant_enabled: + if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() @@ -26,7 +26,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c else: layer['weight_data'] = class_object.weight.data.numpy() - if class_object.is_bias_quant_enabled: + if class_object.bias_quant.is_quant_enabled: width = 
int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() @@ -81,7 +81,7 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['data_format'] = 'channels_first' # Pytorch default (can't change) if "Quant" in operation: - if class_object.is_weight_quant_enabled: + if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() @@ -91,7 +91,7 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c else: layer['weight_data'] = class_object.weight.data.numpy() - if class_object.is_bias_quant_enabled: + if class_object.bias_quant.is_quant_enabled: width = int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index b5bd126181..8604a9db87 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -39,7 +39,7 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c layer['bias_data'] = None if "Quant" in operation: - if class_object.is_weight_quant_enabled: + if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) layer['weight_data'] = class_object.quant_weight().detach().value.numpy() @@ -49,7 +49,7 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c else: layer['weight_data'] = class_object.weight.data.numpy() - if class_object.is_bias_quant_enabled: + if class_object.bias_quant.is_quant_enabled: width = int(class_object.quant_bias().bit_width) ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale)) layer['bias_data'] = class_object.quant_bias().detach().value.numpy() @@ -110,11 +110,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod if "Quant" in operation: layer['class_name'] = operation.split('Quant')[-1] layer['activation'] = layer['class_name'] - bit_width = class_object.quant_act_bit_width() - ap_fixed_params = convert_uaq_to_apfixed(bit_width, class_object.quant_act_scale()) - layer['activation_quantizer'] = BrevitasQuantizer( - bit_width, FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False) - ) + if class_object.act_quant.is_quant_enabled: + bit_width = int(class_object.act_quant.bit_width()) + ap_fixed_params = convert_uaq_to_apfixed(bit_width, float(class_object.act_quant.scale())) + layer['activation_quantizer'] = BrevitasQuantizer( + bit_width, FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False) + ) if node.op == 'call_module': if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']: diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 83ef370dc4..2d4a668fb9 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -208,61 +208,3 @@ def test_quantconv2d(backend, io_type): hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), 
pytorch_prediction.shape) np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) - - -class QuantMaxPool1d(Module): - def __init__(self): - super().__init__() - self.pool = qnn.QuantMaxPool1d(2) - - def forward(self, x): - return self.pool(x) - - -class QuantMaxPool2d(Module): - def __init__(self): - super().__init__() - self.pool = qnn.QuantMaxPool2d(2) - - def forward(self, x): - return self.pool(x) - - -@pytest.mark.parametrize('pooling', [QuantMaxPool1d, QuantMaxPool2d]) -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) -def test_pooling(pooling, backend): - model = pooling() - - assert '1d' in pooling.__name__ or '2d' in pooling.__name__ - - if '2d' in pooling.__name__: - n_in = 2 - size_in_height = 15 - size_in_width = 18 - else: - n_in = 2 - size_in_width = 121 - size_in_height = 0 - - input_shape = (1, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (1, n_in, size_in_width) - input_shape_forHLS = ( - (None, n_in, size_in_height, size_in_width) if '2d' in pooling.__name__ else (None, n_in, size_in_width) - ) - x = torch.randn(*input_shape) - - pytorch_prediction = model(x).tensor.detach().numpy() - - config = config_from_pytorch_model(model, input_shape_forHLS, transpose_outputs=True) - output_dir = str(test_root_path / f'hls4mlprj_brevitas_{pooling.__name__}_{backend}') - - hls_model = convert_from_pytorch_model( - model, - hls_config=config, - output_dir=output_dir, - backend=backend, - ) - hls_model.compile() - - hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) - - np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2) From 7543fabb4ffb3e269184797e9c4a3fb7466c3167 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 19 Feb 2025 14:34:51 -0500 Subject: [PATCH 25/47] add handling of input/output quantization for layers. Non Pow-2 weights are broken. 
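The conv and linear parsers below only accept a brevitas weight scale that is an exact power of two, since only then does a uniform affine quantizer with zero_point = 0 map directly onto an ap_fixed type; any other scale currently raises an exception pointing users to the QONNX frontend. Input and output quantizers, when enabled, are recorded on the layer via addQuantizationParameters for later handling. A minimal sketch of the power-of-two check and of the conversion performed by convert_uaq_to_apfixed, with uaq_to_ap_fixed used as a stand-in helper name for illustration:

import math

import numpy as np


def uaq_to_ap_fixed(bit_width, scale):
    # With zero_point = 0, a power-of-two scale 2**-F corresponds to
    # ap_fixed<bit_width, bit_width - F>, i.e. F fractional bits.
    fract_bits = -math.log2(scale)
    int_bits = bit_width - fract_bits
    return fract_bits, int_bits


scale = 2.0**-6
mantissa, _ = np.frexp(scale)      # frexp returns a mantissa in [0.5, 1)
print(mantissa == 0.5)             # True only for powers of two -> safe to convert
print(uaq_to_ap_fixed(8, scale))   # (6.0, 2.0), i.e. ap_fixed<8, 2>
print(np.frexp(0.1)[0])            # 0.8 -> non-power-of-two scale, rejected for now
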
--- hls4ml/converters/pytorch/convolution.py | 72 ++++++-- hls4ml/converters/pytorch/core.py | 42 +++-- hls4ml/converters/pytorch_to_hls.py | 35 ++++ hls4ml/model/graph.py | 2 +- hls4ml/model/optimizer/__init__.py | 9 + .../optimizer/passes/brevitas_optimizer.py | 157 ++++++++++++++++++ hls4ml/model/optimizer/passes/quant_opt.py | 43 ++++- test/pytest/test_brevitas_parsing.py | 30 +++- 8 files changed, 350 insertions(+), 40 deletions(-) create mode 100644 hls4ml/model/optimizer/passes/brevitas_optimizer.py diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index 22a4cae6f7..a70df8cf39 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,4 +1,6 @@ -from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler +import numpy as np + +from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format from hls4ml.model.quantizers import BrevitasQuantizer from hls4ml.model.types import FixedPrecisionType @@ -18,11 +20,25 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c if "Quant" in operation: if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) - ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) - layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) + scale = class_object.quant_weight().scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + # for non-power-of-2 scales, instead take the unquantized inputs and set parameters + # so an ApplyAlpha node will be added to the model. Currently broken :/ + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight') + # layer['weight_data'] = class_object.quant_weight().detach().value.numpy() else: layer['weight_data'] = class_object.weight.data.numpy() @@ -34,7 +50,15 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = class_object.bias.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None + if class_object.input_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) + if class_object.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) + else: layer['weight_data'] = class_object.weight.data.numpy() if class_object.bias is not None: @@ -83,11 +107,25 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c if "Quant" in operation: if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) - ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) - layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) + scale = class_object.quant_weight().scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + # for non-power-of-2 scales, instead take the unquantized inputs and set parameters so an + # ApplyAlpha node will be added to the model. Currently broken :/ + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight') + # layer['weight_data'] = class_object.quant_weight().detach().value.numpy() else: layer['weight_data'] = class_object.weight.data.numpy() @@ -99,7 +137,15 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = class_object.bias.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None + if class_object.input_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) + if class_object.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) + else: layer['weight_data'] = class_object.weight.data.numpy() if class_object.bias is not None: diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 8604a9db87..2b08ff7789 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -1,6 +1,6 @@ import numpy as np -from hls4ml.converters.pytorch_to_hls import convert_uaq_to_apfixed, pytorch_handler +from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler from hls4ml.model.quantizers import BrevitasQuantizer from hls4ml.model.types import FixedPrecisionType @@ -41,11 +41,25 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c if "Quant" in operation: if class_object.weight_quant.is_quant_enabled: width = int(class_object.quant_weight().bit_width) - ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) - layer['weight_data'] = class_object.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) + scale = class_object.quant_weight().scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + # for non-power-of-2 scales, instead take the quantized inputs and set parameters so an ApplyAlpha node + # will be added to the model. Currently not working :/ + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight') + # layer['weight_data'] = class_object.quant_weight().int().detach().numpy() else: layer['weight_data'] = class_object.weight.data.numpy() @@ -57,7 +71,15 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) else: - layer['bias_data'] = class_object.bias.data.numpy() + if class_object.bias is not None: + layer['bias_data'] = class_object.bias.data.numpy() + else: + layer['bias_data'] = None + if class_object.input_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) + if class_object.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) + else: layer['weight_data'] = class_object.weight.data.numpy() if class_object.bias is not None: @@ -111,11 +133,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod layer['class_name'] = operation.split('Quant')[-1] layer['activation'] = layer['class_name'] if class_object.act_quant.is_quant_enabled: - bit_width = int(class_object.act_quant.bit_width()) - ap_fixed_params = convert_uaq_to_apfixed(bit_width, float(class_object.act_quant.scale())) - layer['activation_quantizer'] = BrevitasQuantizer( - bit_width, FixedPrecisionType(width=bit_width, integer=ap_fixed_params[1], signed=False) - ) + layer = addQuantizationParameters(layer, class_object.act_quant, 'output', act=True) if node.op == 'call_module': if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']: diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index b0de5a4198..7f8e368488 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -91,6 +91,41 @@ def convert_uaq_to_apfixed(bitwidth, scale_factor): return (fract_bitwidth, int_bitwidth) +def addQuantizationParameters(layer, quant_object, quant_type, act=False): + if not act: + print(quant_object.bit_width) + bit_width = int(quant_object.bit_width) + # signed = quant_object.is_signed + signed = quant_object.signed + scale = float(quant_object.scale) + zeropoint = float(quant_object.zero_point) + if signed: + narrow = True + else: + narrow = False + rounding_mode = 'ROUND' + layer['convert_from_brevitas'] = True + else: + bit_width = int(quant_object.bit_width()) + signed = quant_object.is_signed + scale = float(quant_object.scale()) + zeropoint = float(quant_object.zero_point()) + narrow = quant_object.is_narrow_range + rounding_mode = quant_object.rounding_mode + layer['convert_io_from_brevitas'] = True + print(scale) + + layer[f'{quant_type}_quantization'] = { + 'bit_width': bit_width, + 'signed': signed, + 'scale': scale, + 'zeropoint': zeropoint, + 'narrow': narrow, + 'rounding_mode': rounding_mode, + } + return layer + + # ----------------------Layer handling--------------------- # layer_handlers = {} diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 520f96ba5f..ae10d060e0 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -494,7 +494,7 @@ def insert_node(self, node, before=None, input_idx=0): next_nodes.append(x) if before is None: - next_node = next((x for x in self.graph.values() if x.inputs[0] in prev_node.outputs), None) + next_node = next((x for x in self.graph.values() if x.inputs and x.inputs[0] in prev_node.outputs), None) else: if before not 
in next_nodes: raise Exception( diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 7e9325ccd0..cb206ffc84 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -30,6 +30,14 @@ del module_path del optimizers +register_flow( + 'parse_brevitas', + [ + 'brevitas_input_output_optimizer', + 'brevitas_factorize_alpha', + ], +) + register_flow( 'parse_qonnx', [ @@ -53,6 +61,7 @@ 'conv_to_conv_x_d', 'conv_to_depthwise_conv_x_d', ], + requires=['parse_brevitas'], ) register_flow( diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py new file mode 100644 index 0000000000..f7ac4ab9ab --- /dev/null +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -0,0 +1,157 @@ +# Conversion of model from channels_first to channels_last data format +# Based on https://github.com/fastmachinelearning/qonnx/blob/ +# 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py +import math + +import numpy as np + +from hls4ml.model.layers import ApplyAlpha +from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.optimizer.passes.quant_opt import _calculate_precision_quantizer +from hls4ml.model.types import NamedType, find_minimum_width + + +class BrevitasInputOutputOptimizer(OptimizerPass): + '''Takes nodes parsed from brevitas and inserts Quant nodes into the model if necessary''' + + def match(self, node): + needs_conversion = False + if 'convert_io_from_brevitas' in node.attributes.keys(): + needs_conversion = node.attributes['convert_io_from_brevitas'] and ( + 'output_quantization' in node.attributes.keys() or 'input_quantization' in node.attributes.keys() + ) + + return needs_conversion + + def transform(self, model, node): + + # See if Quant layer needs to be added for the output + if 'output_quantization' in node.attributes.keys(): + + attributes = {} + + input = node.name + # Other attributes + attributes['narrow'] = node.attributes['output_quantization']['narrow'] + attributes['rounding_mode'] = node.attributes['output_quantization']['rounding_mode'] + attributes['signed'] = node.attributes['output_quantization']['signed'] + attributes['bitwidth'] = node.attributes['output_quantization']['bit_width'] + attributes['zeropt'] = node.attributes['output_quantization']['zeropoint'] + attributes['scale'] = np.array([node.attributes['output_quantization']['scale']]) + + quant_node = model.make_node('Quant', f'quant_output_for_{node.get_attr("name")}', attributes, [input]) + quant_node.set_attr('name', f'quant_output_for_{node.get_attr("name")}') + + model.insert_node(quant_node) + + node.attributes['convert_io_from_brevitas'] = False + + elif 'input_quantization' in node.attributes.keys(): + + attributes = {} + + input = node.inputs[0] + # Other attributes + attributes['narrow'] = node.attributes['input_quantization']['narrow'] + attributes['rounding_mode'] = node.attributes['input_quantization']['rounding_mode'] + attributes['signed'] = node.attributes['input_quantization']['signed'] + attributes['bitwidth'] = node.attributes['input_quantization']['bit_width'] + attributes['zeropt'] = node.attributes['input_quantization']['zeropoint'] + attributes['scale'] = np.array([node.attributes['input_quantization']['scale']]) + + quant_node = model.make_node('Quant', f'quant_input_for_{node.get_attr("name")}', attributes, [input]) + quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}') + + model.insert_node(quant_node) + + 
node.attributes['convert_io_from_brevitas'] = False + return True + + +class BrevitasFactorizeAlpha(OptimizerPass): + '''OptimizerPass for extracting alpha "scale" from Brevitas quantized layer. + The weights of the Quant{Dense, Conv} layer are scaled to the common data type, + and an 'ApplyAlpha' layer is inserted to reapply the scale. + ''' + + def match(self, node): + q_layer = node.class_name in ['Dense', 'QConv1D', 'Conv2D'] + + has_w_alpha = 'weight_quantization' in node.attributes.keys() + has_b_alpha = 'bias_quantization' in node.attributes.keys() + + needs_conversion = False + if 'convert_from_brevitas' in node.attributes.keys(): + needs_conversion = node.attributes['convert_from_brevitas'] + + is_match = q_layer and needs_conversion and (has_w_alpha or has_b_alpha) + return is_match + + def transform(self, model, node): + # The quantizer has to be applied to set the scale attribute + # This must be applied to the _unquantized_ weights to obtain the correct scale + if node.attributes['convert_from_brevitas'] is False: + return False + scale = np.full(node.weights['weight'].data.shape, [node.attributes['weight_quantization']['scale']]) + + # find number of bits to represent unscaled weight tensor (should be the full bit width, but better be sure) + # and set precision for weight variable + int_bits = find_minimum_width(node.weights['weight'].data, signed=True) + + unscale_precision, _ = _calculate_precision_quantizer(int_bits, int_bits, True, True, 'FLOOR') + node.weights['weight'].type = NamedType(node.weights['weight'].name + '_t', unscale_precision) + res_precision, _ = _calculate_precision_quantizer(int_bits * 2, int_bits, True, True, 'FLOOR') + node.types['accum_t'] = NamedType(node.name + '_accum_t', res_precision) + node.types['result_t'].type = res_precision + + # Move the biases from the Dense layer to the ApplyAlpha layer + bias = node.weights['bias'].data + node.weights['bias'].data = np.zeros(bias.shape) + + # insert a Batch Normalization layer to apply the alpha scale + if 'Linear' in node.class_name: + n_in = node.get_attr('n_out') + elif 'Conv' in node.class_name: + n_in = node.get_attr('out_width') * node.get_attr('out_height', 1) * node.get_attr('n_filt') + else: + n_in = node.get_attr('n_out') + + # the name of the new ApplyAlpha node + alpha_name = node.get_attr('name') + '_alpha' + + # make the precision auto + alpha_precision = {'Precision': 'auto'} + model.config.set_name_config(alpha_name, alpha_precision) + model.config.parse_name_config(alpha_name, alpha_precision) + + # This part is very stupid, since this basically just results in the scale being represented at 2*bith width, + # otherwise it just uses full system float precision. 
Needs work + fractional_part, integer_part = math.modf(node.attributes['weight_quantization']['scale']) + if integer_part > 0: + int_bits = math.ceil(math.log2(integer_part)) + 1 + else: + int_bits = 0 + frac_bits = math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) + scale_precision, scale_quantizer = _calculate_precision_quantizer( + int_bits + frac_bits, int_bits, True, False, 'FLOOR' + ) + + attrs = { + 'name': alpha_name, + 'class_name': 'Alpha', + 'inputs': node.outputs, + 'n_in': n_in, + 'n_filt': node.get_attr('n_filt', -1), + 'reuse_factor': node.get_attr('reuse_factor'), + 'scale_data': scale, + 'scale_quantizer': scale_quantizer, + 'scale_precision': scale_precision, + 'bias_data': bias, + 'bias_quantizer': None, + 'bias_precision': None, + 'trace': node.get_attr('trace', False), + } + alpha_layer = model.make_node(ApplyAlpha, node.name + '_alpha', attrs, node.outputs) + model.insert_node(alpha_layer) + node.attributes['convert_from_brevitas'] = False + return True diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 04d5393748..ffc3980e33 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -100,7 +100,6 @@ def match(self, node): scale = node.get_attr('scale') bias = node.get_attr('zeropt') is_match = is_match and (bias == np.zeros_like(bias)).all() - # check if scale is ones-like or a power of two scale_unit_or_po2 = (scale == np.ones_like(scale)).all() if not scale_unit_or_po2 and _ALSO_MATCH_PO2: @@ -232,7 +231,6 @@ def transform(self, model, node): narrow = node.get_attr('narrow') signed = node.get_attr('signed') bitwidth = node.get_attr('bitwidth') - precision, quantizer = _calculate_precision_quantizer(bitwidth, bitwidth, signed, narrow, rounding_mode) activation_attributes = {'activation': 'linear', 'quantizer': quantizer} @@ -245,8 +243,14 @@ def transform(self, model, node): act_name = f'{node.name}_act' model.config.set_name_config(act_name, act_config) model.config.parse_name_config(act_name, act_config) - - new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs]) + if 'global_out' in node.outputs: + new_node = model.make_node( + Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs] + ) + else: + new_node = model.make_node( + Activation, act_name, activation_attributes, [node.inputs[0]], [x + '_act' for x in node.outputs] + ) model.replace_node(node, new_node) # but now add the ApplyAlhpas before and after @@ -268,19 +272,48 @@ def transform(self, model, node): rescale_name = f'{node.name}_rescale' model.config.set_name_config(rescale_name, rescale_config) model.config.parse_name_config(rescale_name, rescale_config) - firstscale = 1 / scale + + # need to adjust data type to account for the fact that the inverse scale needs mostly integer bits + fractional_part, integer_part = math.modf(firstscale) + int_bits = math.ceil(math.log2(integer_part)) + 1 + frac_bits = min( + node.get_attr('bitwidth'), + math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) + 1, + ) + scale_precision, scale_quantizer = _calculate_precision_quantizer( + int_bits + frac_bits, int_bits, False, False, 'FLOOR' + ) + firstbias = bias attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape) attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape) + attributes_scale['scale_quantizer'] = scale_quantizer + 
attributes_scale['scale_precision'] = scale_precision scale_node = model.make_node(ApplyAlpha, scale_name, attributes_scale, [node.inputs[0]]) + scale_node.types['result_t'].precision = scale_precision model.insert_node(scale_node) + fractional_part, integer_part = math.modf(scale) + if integer_part > 0: + int_bits = math.ceil(math.log2(integer_part)) + 1 + else: + int_bits = 0 + frac_bits = min( + node.get_attr('bitwidth') * 2, + math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) + 1, + ) + scale_precision, scale_quantizer = _calculate_precision_quantizer( + int_bits + frac_bits, int_bits, False, False, 'FLOOR' + ) + rescale = scale rebias = -bias * scale attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape) attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape) + attributes_rescale['scale_quantizer'] = scale_quantizer + attributes_rescale['scale_precision'] = scale_precision rescale_node = model.make_node(ApplyAlpha, rescale_name, attributes_rescale, [new_node.outputs[0]]) model.insert_node(rescale_node) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 2d4a668fb9..f39cc3a6af 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -4,7 +4,7 @@ import numpy as np import pytest import torch -from brevitas.quant import Int8WeightPerTensorFixedPoint +from brevitas.quant import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint, Int8WeightPerTensorFloat from torch import nn from torch.nn import Module @@ -13,6 +13,12 @@ test_root_path = Path(__file__).parent +quants = { + 'Int8WeightPerTensorFixedPoint': Int8WeightPerTensorFixedPoint, + 'Int8ActPerTensorFixedPoint': Int8ActPerTensorFixedPoint, + 'Int8WeightPerTensorFloat': Int8WeightPerTensorFloat, +} + class QuantModelConv2d(Module): def __init__(self): @@ -37,9 +43,12 @@ def forward(self, x): class QuantModelLinear(Module): - def __init__(self): + def __init__(self, weight_quant, input_quant): super().__init__() - self.conv1 = qnn.QuantLinear(4, 4, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + # self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=quants[weight_quant], input_quant=quants[input_quant]) + self.conv1 = qnn.QuantLinear( + 4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint, input_quant=Int8ActPerTensorFixedPoint + ) self.relu1 = qnn.QuantReLU() def forward(self, x): @@ -47,16 +56,19 @@ def forward(self, x): return out -@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) -def test_quantlinear(backend, io_type): - model = QuantModelLinear() - - x = torch.tensor([1.0, 2.0, 3.0, 4.0]) +@pytest.mark.parametrize('weight_quant', ['Int8WeightPerTensorFixedPoint']) +@pytest.mark.parametrize('io_quant', ['Int8ActPerTensorFixedPoint']) +def test_quantlinear(backend, io_type, weight_quant, io_quant): + # def test_quantlinear(backend, io_type): + model = QuantModelLinear(weight_quant, io_quant) + x = torch.rand(1, 4) pytorch_prediction = model(x).detach().numpy() config = config_from_pytorch_model(model, input_shape=(None, 4)) - output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + # output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') + output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}_{weight_quant}_{io_quant}') 
hls_model = convert_from_pytorch_model( model, From be87dd4caa56a7ac89cabf0eac8550ba434b748a Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Thu, 20 Feb 2025 15:41:01 -0500 Subject: [PATCH 26/47] hack around lazy imports --- hls4ml/converters/pytorch_to_hls.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index aa168b3754..bcb25661b5 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -2,19 +2,27 @@ import numpy as np -# this import conficts with our new lazy imports. Not sure how to handle this otherwise yet -import torch - from hls4ml.model import ModelGraph from hls4ml.utils.dependency import requires +# this import conficts with our new lazy imports. Not sure how to handle this otherwise yet + -class CustomFXTracer(torch.fx.Tracer): +class CustomFXTracer: - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + def __new__(cls, *args, **kwargs): + import torch.fx as fx + + new_type = type('CustomFXTracer', (CustomFXTracer, fx.Tracer, object), {}) + instance = super().__new__(new_type) + return instance + + def is_leaf_module(self, m, module_qualified_name: str) -> bool: """ Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX """ + import torch + return ( m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn") From c9d2415d4330e4ca5f23505b69f2f1357b4d88e0 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 10:52:28 -0500 Subject: [PATCH 27/47] tidy up and support only power of 2 scaling for now --- hls4ml/converters/pytorch/convolution.py | 6 - hls4ml/converters/pytorch/core.py | 4 - hls4ml/converters/pytorch_to_hls.py | 33 +----- hls4ml/model/graph.py | 4 +- hls4ml/model/layers.py | 13 ++- hls4ml/model/optimizer/__init__.py | 1 - .../optimizer/passes/brevitas_optimizer.py | 105 +----------------- hls4ml/model/optimizer/passes/quant_opt.py | 40 +------ test/pytest/test_brevitas_parsing.py | 5 +- 9 files changed, 26 insertions(+), 185 deletions(-) diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index a70df8cf39..5a39c9fc4a 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -30,15 +30,11 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c layer['weight_quantizer'] = BrevitasQuantizer( width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) - # for non-power-of-2 scales, instead take the unquantized inputs and set parameters - # so an ApplyAlpha node will be added to the model. Currently broken :/ else: raise Exception( '''Non-power of 2 quantization of weights not supported when injecting brevitas models. Please used QONNX instead.''' ) - # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight') - # layer['weight_data'] = class_object.quant_weight().detach().value.numpy() else: layer['weight_data'] = class_object.weight.data.numpy() @@ -117,8 +113,6 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c layer['weight_quantizer'] = BrevitasQuantizer( width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) - # for non-power-of-2 scales, instead take the unquantized inputs and set parameters so an - # ApplyAlpha node will be added to the model. 
Currently broken :/ else: raise Exception( '''Non-power of 2 quantization of weights not supported when injecting brevitas models. diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 2b08ff7789..9ceb4d1a8a 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -51,15 +51,11 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c layer['weight_quantizer'] = BrevitasQuantizer( width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) - # for non-power-of-2 scales, instead take the quantized inputs and set parameters so an ApplyAlpha node - # will be added to the model. Currently not working :/ else: raise Exception( '''Non-power of 2 quantization of weights not supported when injecting brevitas models. Please used QONNX instead.''' ) - # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight') - # layer['weight_data'] = class_object.quant_weight().int().detach().numpy() else: layer['weight_data'] = class_object.weight.data.numpy() diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index bcb25661b5..ed715c7a06 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -2,33 +2,10 @@ import numpy as np +from hls4ml.converters.pytorch.tracer import CustomFXTracer from hls4ml.model import ModelGraph from hls4ml.utils.dependency import requires -# this import conficts with our new lazy imports. Not sure how to handle this otherwise yet - - -class CustomFXTracer: - - def __new__(cls, *args, **kwargs): - import torch.fx as fx - - new_type = type('CustomFXTracer', (CustomFXTracer, fx.Tracer, object), {}) - instance = super().__new__(new_type) - return instance - - def is_leaf_module(self, m, module_qualified_name: str) -> bool: - """ - Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX - """ - import torch - - return ( - m.__module__.startswith("torch.nn") - or m.__module__.startswith("torch.ao.nn") - or m.__module__.startswith("brevitas.nn") - ) and not isinstance(m, torch.nn.Sequential) - class PyTorchModelReader: """ @@ -104,11 +81,12 @@ def convert_uaq_to_apfixed(bitwidth, scale_factor): return (fract_bitwidth, int_bitwidth) +# embed quantization information into the layer dictionary for a Quant layer +# so that this layer can be added to the model def addQuantizationParameters(layer, quant_object, quant_type, act=False): if not act: - print(quant_object.bit_width) + # currently not used, might be use later for non-power-of-2 scales bit_width = int(quant_object.bit_width) - # signed = quant_object.is_signed signed = quant_object.signed scale = float(quant_object.scale) zeropoint = float(quant_object.zero_point) @@ -117,7 +95,6 @@ def addQuantizationParameters(layer, quant_object, quant_type, act=False): else: narrow = False rounding_mode = 'ROUND' - layer['convert_from_brevitas'] = True else: bit_width = int(quant_object.bit_width()) signed = quant_object.is_signed @@ -125,8 +102,6 @@ def addQuantizationParameters(layer, quant_object, quant_type, act=False): zeropoint = float(quant_object.zero_point()) narrow = quant_object.is_narrow_range rounding_mode = quant_object.rounding_mode - layer['convert_io_from_brevitas'] = True - print(scale) layer[f'{quant_type}_quantization'] = { 'bit_width': bit_width, diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 9ad48d718f..cf3bb4362e 100644 --- 
a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -565,6 +565,9 @@ def remove_node(self, node, rewire=True): if outputs[0] == nxt_inp: next_node.inputs[i] = inputs[0] + if node.outputs[0] in self.outputs: + prev_node = node.get_input_node(node.inputs[0]) + self.outputs[self.outputs.index(node.outputs[0])] = prev_node.outputs[0] del self.output_vars[node.outputs[0]] del self.graph[node.name] @@ -594,7 +597,6 @@ def replace_node(self, old_node, new_node): for i, n in enumerate(node.outputs): if n in repl: node.outputs[i] = repl[n] - self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) old_name = old_node.name diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index aac11cc7a3..3e8b897393 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -386,14 +386,17 @@ def initialize(self): class Quant(Layer): # The QONNX quantization layer """ - This is a QONNX quantization layer. Optimizations should convert it - before HLS is produced. + This is a QONNX quantization layer. Can also be inserted in direct brevitas parsing. + Optimizations should convert it before HLS is produced. """ _expected_attributes = [ Attribute('narrow', value_type=bool), Attribute('rounding_mode', value_type=str), Attribute('signed', value_type=bool), + Attribute('scale', value_type=float), + Attribute('zeropt', value_type=int), + Attribute('bitwidth', value_type=int), ] def initialize(self): @@ -452,6 +455,8 @@ class Dense(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): @@ -500,6 +505,8 @@ class Conv1D(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): @@ -611,6 +618,8 @@ class Conv2D(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index cb206ffc84..b0062c59c8 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -34,7 +34,6 @@ 'parse_brevitas', [ 'brevitas_input_output_optimizer', - 'brevitas_factorize_alpha', ], ) diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index f7ac4ab9ab..79387a37fa 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -1,14 +1,7 @@ -# Conversion of model from channels_first to channels_last data format -# Based on https://github.com/fastmachinelearning/qonnx/blob/ -# 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py -import math - +# Inserts Quant nodes into the model as needed for input/output quantization of layers in brevitas import numpy as np -from hls4ml.model.layers import ApplyAlpha from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.optimizer.passes.quant_opt import _calculate_precision_quantizer -from hls4ml.model.types import NamedType, find_minimum_width class BrevitasInputOutputOptimizer(OptimizerPass): @@ -26,7 +19,7 @@ def match(self, 
node): def transform(self, model, node): # See if Quant layer needs to be added for the output - if 'output_quantization' in node.attributes.keys(): + if 'output_quantization' in node.attributes.keys() and not len(node.attributes['output_quantization']) == 0: attributes = {} @@ -44,9 +37,9 @@ def transform(self, model, node): model.insert_node(quant_node) - node.attributes['convert_io_from_brevitas'] = False + node.attributes['output_quantization'] = {} - elif 'input_quantization' in node.attributes.keys(): + elif 'input_quantization' in node.attributes.keys() and not len(node.attributes['input_quantization']) == 0: attributes = {} @@ -64,94 +57,6 @@ def transform(self, model, node): model.insert_node(quant_node) - node.attributes['convert_io_from_brevitas'] = False - return True - - -class BrevitasFactorizeAlpha(OptimizerPass): - '''OptimizerPass for extracting alpha "scale" from Brevitas quantized layer. - The weights of the Quant{Dense, Conv} layer are scaled to the common data type, - and an 'ApplyAlpha' layer is inserted to reapply the scale. - ''' - - def match(self, node): - q_layer = node.class_name in ['Dense', 'QConv1D', 'Conv2D'] + node.attributes['input_quantization'] = {} - has_w_alpha = 'weight_quantization' in node.attributes.keys() - has_b_alpha = 'bias_quantization' in node.attributes.keys() - - needs_conversion = False - if 'convert_from_brevitas' in node.attributes.keys(): - needs_conversion = node.attributes['convert_from_brevitas'] - - is_match = q_layer and needs_conversion and (has_w_alpha or has_b_alpha) - return is_match - - def transform(self, model, node): - # The quantizer has to be applied to set the scale attribute - # This must be applied to the _unquantized_ weights to obtain the correct scale - if node.attributes['convert_from_brevitas'] is False: - return False - scale = np.full(node.weights['weight'].data.shape, [node.attributes['weight_quantization']['scale']]) - - # find number of bits to represent unscaled weight tensor (should be the full bit width, but better be sure) - # and set precision for weight variable - int_bits = find_minimum_width(node.weights['weight'].data, signed=True) - - unscale_precision, _ = _calculate_precision_quantizer(int_bits, int_bits, True, True, 'FLOOR') - node.weights['weight'].type = NamedType(node.weights['weight'].name + '_t', unscale_precision) - res_precision, _ = _calculate_precision_quantizer(int_bits * 2, int_bits, True, True, 'FLOOR') - node.types['accum_t'] = NamedType(node.name + '_accum_t', res_precision) - node.types['result_t'].type = res_precision - - # Move the biases from the Dense layer to the ApplyAlpha layer - bias = node.weights['bias'].data - node.weights['bias'].data = np.zeros(bias.shape) - - # insert a Batch Normalization layer to apply the alpha scale - if 'Linear' in node.class_name: - n_in = node.get_attr('n_out') - elif 'Conv' in node.class_name: - n_in = node.get_attr('out_width') * node.get_attr('out_height', 1) * node.get_attr('n_filt') - else: - n_in = node.get_attr('n_out') - - # the name of the new ApplyAlpha node - alpha_name = node.get_attr('name') + '_alpha' - - # make the precision auto - alpha_precision = {'Precision': 'auto'} - model.config.set_name_config(alpha_name, alpha_precision) - model.config.parse_name_config(alpha_name, alpha_precision) - - # This part is very stupid, since this basically just results in the scale being represented at 2*bith width, - # otherwise it just uses full system float precision. 
Needs work - fractional_part, integer_part = math.modf(node.attributes['weight_quantization']['scale']) - if integer_part > 0: - int_bits = math.ceil(math.log2(integer_part)) + 1 - else: - int_bits = 0 - frac_bits = math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) - scale_precision, scale_quantizer = _calculate_precision_quantizer( - int_bits + frac_bits, int_bits, True, False, 'FLOOR' - ) - - attrs = { - 'name': alpha_name, - 'class_name': 'Alpha', - 'inputs': node.outputs, - 'n_in': n_in, - 'n_filt': node.get_attr('n_filt', -1), - 'reuse_factor': node.get_attr('reuse_factor'), - 'scale_data': scale, - 'scale_quantizer': scale_quantizer, - 'scale_precision': scale_precision, - 'bias_data': bias, - 'bias_quantizer': None, - 'bias_precision': None, - 'trace': node.get_attr('trace', False), - } - alpha_layer = model.make_node(ApplyAlpha, node.name + '_alpha', attrs, node.outputs) - model.insert_node(alpha_layer) - node.attributes['convert_from_brevitas'] = False return True diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index ffc3980e33..88453c616c 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -135,7 +135,7 @@ def transform(self, model, node): config = model.config.get_layer_config(node) prec_config = config.setdefault('Precision', {}) prec_config['result'] = str(precision) - new_name = f'{node.name}_act' + new_name = node.name model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) @@ -243,14 +243,7 @@ def transform(self, model, node): act_name = f'{node.name}_act' model.config.set_name_config(act_name, act_config) model.config.parse_name_config(act_name, act_config) - if 'global_out' in node.outputs: - new_node = model.make_node( - Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs] - ) - else: - new_node = model.make_node( - Activation, act_name, activation_attributes, [node.inputs[0]], [x + '_act' for x in node.outputs] - ) + new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs]) model.replace_node(node, new_node) # but now add the ApplyAlhpas before and after @@ -274,46 +267,17 @@ def transform(self, model, node): model.config.parse_name_config(rescale_name, rescale_config) firstscale = 1 / scale - # need to adjust data type to account for the fact that the inverse scale needs mostly integer bits - fractional_part, integer_part = math.modf(firstscale) - int_bits = math.ceil(math.log2(integer_part)) + 1 - frac_bits = min( - node.get_attr('bitwidth'), - math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) + 1, - ) - scale_precision, scale_quantizer = _calculate_precision_quantizer( - int_bits + frac_bits, int_bits, False, False, 'FLOOR' - ) - firstbias = bias attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape) attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape) - attributes_scale['scale_quantizer'] = scale_quantizer - attributes_scale['scale_precision'] = scale_precision scale_node = model.make_node(ApplyAlpha, scale_name, attributes_scale, [node.inputs[0]]) - scale_node.types['result_t'].precision = scale_precision model.insert_node(scale_node) - fractional_part, integer_part = math.modf(scale) - if integer_part > 0: - int_bits = math.ceil(math.log2(integer_part)) + 1 - else: - int_bits = 0 - frac_bits = min( - 
node.get_attr('bitwidth') * 2, - math.ceil(math.log2(fractional_part * (10 ** len(str(fractional_part).split('.')[1])))) + 1, - ) - scale_precision, scale_quantizer = _calculate_precision_quantizer( - int_bits + frac_bits, int_bits, False, False, 'FLOOR' - ) - rescale = scale rebias = -bias * scale attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape) attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape) - attributes_rescale['scale_quantizer'] = scale_quantizer - attributes_rescale['scale_precision'] = scale_precision rescale_node = model.make_node(ApplyAlpha, rescale_name, attributes_rescale, [new_node.outputs[0]]) model.insert_node(rescale_node) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index f39cc3a6af..3bf9cb55c2 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -45,11 +45,10 @@ def forward(self, x): class QuantModelLinear(Module): def __init__(self, weight_quant, input_quant): super().__init__() - # self.conv1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=quants[weight_quant], input_quant=quants[input_quant]) self.conv1 = qnn.QuantLinear( 4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint, input_quant=Int8ActPerTensorFixedPoint ) - self.relu1 = qnn.QuantReLU() + self.relu1 = qnn.QuantReLU(act_quant=Int8ActPerTensorFixedPoint) def forward(self, x): out = self.relu1(self.conv1(x)) @@ -61,13 +60,11 @@ def forward(self, x): @pytest.mark.parametrize('weight_quant', ['Int8WeightPerTensorFixedPoint']) @pytest.mark.parametrize('io_quant', ['Int8ActPerTensorFixedPoint']) def test_quantlinear(backend, io_type, weight_quant, io_quant): - # def test_quantlinear(backend, io_type): model = QuantModelLinear(weight_quant, io_quant) x = torch.rand(1, 4) pytorch_prediction = model(x).detach().numpy() config = config_from_pytorch_model(model, input_shape=(None, 4)) - # output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}') output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}_{weight_quant}_{io_quant}') hls_model = convert_from_pytorch_model( From 8910fa639d73d59971fd19656e2947d607889fe5 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 10:56:06 -0500 Subject: [PATCH 28/47] add missing file --- hls4ml/converters/pytorch/tracer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 hls4ml/converters/pytorch/tracer.py diff --git a/hls4ml/converters/pytorch/tracer.py b/hls4ml/converters/pytorch/tracer.py new file mode 100644 index 0000000000..babddad0c1 --- /dev/null +++ b/hls4ml/converters/pytorch/tracer.py @@ -0,0 +1,16 @@ +import torch + + +class CustomFXTracer(torch.fx.Tracer): + + def is_leaf_module(self, m, module_qualified_name: str) -> bool: + """ + Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX + """ + import torch + + return ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + or m.__module__.startswith("brevitas.nn") + ) and not isinstance(m, torch.nn.Sequential) From 680728f30a5fbafce4c17cadb1059de25e20962a Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 15:31:37 -0500 Subject: [PATCH 29/47] move tracer import --- hls4ml/converters/pytorch_to_hls.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index ed715c7a06..cbffd51f72 100644 --- 
a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -2,7 +2,6 @@ import numpy as np -from hls4ml.converters.pytorch.tracer import CustomFXTracer from hls4ml.model import ModelGraph from hls4ml.utils.dependency import requires @@ -173,6 +172,8 @@ def parse_pytorch_model(config, verbose=True): """ import torch + from hls4ml.converters.pytorch.tracer import CustomFXTracer + # This is a list of dictionaries to hold all the layer info we need to generate HLS layer_list = [] From 1319791ddf7e37db9deaa121b7cd751a0d483f52 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 15:55:19 -0500 Subject: [PATCH 30/47] add brevitas to testing environment --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7c67b3edfa..c59f294c90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ optional-dependencies.qkeras = [ optional-dependencies.quartus-report = [ "calmjs-parse", "tabulate" ] optional-dependencies.sr = [ "sympy" ] optional-dependencies.testing = [ + "brevitas", "calmjs-parse", "hgq>=0.2.3", "onnx>=1.4", From 1d08c8ab4aa6788e72f68d1207a916aa79501bdb Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 16:03:03 -0500 Subject: [PATCH 31/47] fix import path in pytests --- test/pytest/test_brevitas_parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 3bf9cb55c2..bec3f94c0e 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -105,7 +105,7 @@ def test_quantconv1d(backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') - from hls4ml.converters.pytorch_to_hls import CustomFXTracer + from hls4ml.converters.pytorch.tracer import CustomFXTracer tracer = CustomFXTracer() traced_model = tracer.trace(model) From d064b0c0fd3f333ad3c19815269318c168a6ad7f Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Feb 2025 18:02:24 -0500 Subject: [PATCH 32/47] fix import path in pytests 2 --- test/pytest/test_brevitas_parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index bec3f94c0e..68b8b93442 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -166,7 +166,7 @@ def test_quantconv2d(backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}') - from hls4ml.converters.pytorch_to_hls import CustomFXTracer + from hls4ml.converters.pytorch.tracer import CustomFXTracer tracer = CustomFXTracer() traced_model = tracer.trace(model) From fb30b35df4ed6851c6be8443aee2be21dd000615 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Sat, 22 Feb 2025 09:58:13 -0500 Subject: [PATCH 33/47] fix revert accidental changes --- hls4ml/model/optimizer/passes/quant_opt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 88453c616c..3fe9773877 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -100,6 +100,7 @@ def match(self, node): scale = node.get_attr('scale') bias = node.get_attr('zeropt') is_match = is_match and (bias == np.zeros_like(bias)).all() + # check if scale is ones-like or a power of two scale_unit_or_po2 = (scale == 
np.ones_like(scale)).all() if not scale_unit_or_po2 and _ALSO_MATCH_PO2: @@ -135,7 +136,7 @@ def transform(self, model, node): config = model.config.get_layer_config(node) prec_config = config.setdefault('Precision', {}) prec_config['result'] = str(precision) - new_name = node.name + new_name = f'{node.name}_act' model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) @@ -231,6 +232,7 @@ def transform(self, model, node): narrow = node.get_attr('narrow') signed = node.get_attr('signed') bitwidth = node.get_attr('bitwidth') + precision, quantizer = _calculate_precision_quantizer(bitwidth, bitwidth, signed, narrow, rounding_mode) activation_attributes = {'activation': 'linear', 'quantizer': quantizer} @@ -243,6 +245,7 @@ def transform(self, model, node): act_name = f'{node.name}_act' model.config.set_name_config(act_name, act_config) model.config.parse_name_config(act_name, act_config) + new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs]) model.replace_node(node, new_node) @@ -266,7 +269,6 @@ def transform(self, model, node): model.config.set_name_config(rescale_name, rescale_config) model.config.parse_name_config(rescale_name, rescale_config) firstscale = 1 / scale - firstbias = bias attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape) attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape) From bb36f7199708ce1e4124f8eb73fd63922258c35b Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 24 Feb 2025 08:35:56 -0500 Subject: [PATCH 34/47] remove unnecessary attributes from Quant class --- hls4ml/model/layers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 3e8b897393..b60d4558c7 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -394,9 +394,6 @@ class Quant(Layer): # The QONNX quantization layer Attribute('narrow', value_type=bool), Attribute('rounding_mode', value_type=str), Attribute('signed', value_type=bool), - Attribute('scale', value_type=float), - Attribute('zeropt', value_type=int), - Attribute('bitwidth', value_type=int), ] def initialize(self): From ff4b2e297ea40f38c360550071a67ff6e1aaeb44 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 24 Feb 2025 09:51:47 -0500 Subject: [PATCH 35/47] add support for QuantIdentity layers --- hls4ml/converters/pytorch/core.py | 26 +++++++++++++++++++ .../optimizer/passes/brevitas_optimizer.py | 13 +++++----- hls4ml/model/optimizer/passes/quant_opt.py | 4 +-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 9ceb4d1a8a..98b56731fa 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -22,6 +22,32 @@ def parse_constant_layer(operation, layer_name, node): return layer, output_shape +# A QuantIdentity layer does nothing but quantize its inputs. 
Insert `Quant` node to be processed by QONNX optimizers +@pytorch_handler('QuantIdentity') +def parse_quantidentity_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert 'QuantIdentity' in operation + + layer = {} + layer['inputs'] = input_names + + layer['class_name'] = 'Quant' + layer['name'] = layer_name + print(input_shapes) + if class_object.act_quant.is_quant_enabled: + layer['bitwidth'] = int(class_object.act_quant.bit_width()) + layer['signed'] = class_object.act_quant.is_signed + layer['scale'] = np.full(np.array(input_shapes[0][1:]), class_object.act_quant.scale()) + layer['zeropt'] = float(class_object.act_quant.zero_point()) + layer['narrow'] = class_object.act_quant.is_narrow_range + layer['rounding_mode'] = class_object.act_quant.rounding_mode + + else: + raise Exception('''QuantIdentify layer without act quant does nothing, please remove from model.''') + output_shape = input_shapes[0] + + return layer, output_shape + + @pytorch_handler('Linear', 'QuantLinear') def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Linear' in operation diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index 79387a37fa..8ecbb1c68a 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -8,13 +8,12 @@ class BrevitasInputOutputOptimizer(OptimizerPass): '''Takes nodes parsed from brevitas and inserts Quant nodes into the model if necessary''' def match(self, node): - needs_conversion = False - if 'convert_io_from_brevitas' in node.attributes.keys(): - needs_conversion = node.attributes['convert_io_from_brevitas'] and ( - 'output_quantization' in node.attributes.keys() or 'input_quantization' in node.attributes.keys() - ) - - return needs_conversion + if ('output_quantization' in node.attributes.keys() and not len(node.attributes['output_quantization']) == 0) or ( + 'input_quantization' in node.attributes.keys() and not len(node.attributes['input_quantization']) == 0 + ): + return True + else: + return False def transform(self, model, node): diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 3fe9773877..483795c709 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -139,8 +139,8 @@ def transform(self, model, node): new_name = f'{node.name}_act' model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) - - new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + print("making new node") + new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [f'{x}_act' for x in node.outputs]) model.replace_node(node, new_node) return True From d91bdf0184a9b89444a2ccc26d924285c279d781 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 24 Feb 2025 12:52:14 -0500 Subject: [PATCH 36/47] fix output names --- hls4ml/model/optimizer/passes/quant_opt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 483795c709..ed33141edd 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -139,7 +139,6 @@ def transform(self, model, node): new_name = f'{node.name}_act' 
model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) - print("making new node") new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [f'{x}_act' for x in node.outputs]) model.replace_node(node, new_node) From 2f1ce04121ac58423ac3f9ca9567cd19467b14b2 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Tue, 25 Feb 2025 10:30:43 -0500 Subject: [PATCH 37/47] add support for QuantCat layer - which seems broken :/ --- hls4ml/converters/pytorch/merge.py | 10 ++++- .../optimizer/passes/brevitas_optimizer.py | 9 ++-- test/pytest/test_brevitas_parsing.py | 43 +++++++++++++++++++ 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/hls4ml/converters/pytorch/merge.py b/hls4ml/converters/pytorch/merge.py index 1f1e11dcb7..7d42b82fa9 100644 --- a/hls4ml/converters/pytorch/merge.py +++ b/hls4ml/converters/pytorch/merge.py @@ -1,6 +1,6 @@ -from hls4ml.converters.pytorch_to_hls import pytorch_handler +from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, pytorch_handler -concat_layers = ['cat', 'concat', 'concatenate'] +concat_layers = ['cat', 'concat', 'concatenate', 'QuantCat'] @pytorch_handler(*concat_layers) @@ -25,6 +25,12 @@ def parse_concat_layer(operation, layer_name, input_names, input_shapes, node, c output_shape = input_shapes[0][:] output_shape[layer['axis']] += input_shapes[1][layer['axis']] + if "Quant" in layer_name: + if class_object.input_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) + if class_object.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) + return layer, output_shape diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index 8ecbb1c68a..afa3d345d2 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -42,7 +42,6 @@ def transform(self, model, node): attributes = {} - input = node.inputs[0] # Other attributes attributes['narrow'] = node.attributes['input_quantization']['narrow'] attributes['rounding_mode'] = node.attributes['input_quantization']['rounding_mode'] @@ -51,8 +50,12 @@ def transform(self, model, node): attributes['zeropt'] = node.attributes['input_quantization']['zeropoint'] attributes['scale'] = np.array([node.attributes['input_quantization']['scale']]) - quant_node = model.make_node('Quant', f'quant_input_for_{node.get_attr("name")}', attributes, [input]) - quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}') + for i, input in enumerate(node.inputs): + + quant_node = model.make_node( + 'Quant', f'quant_input_for_{node.get_attr("name")}_input_{i}', attributes, [input] + ) + quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}') model.insert_node(quant_node) diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 68b8b93442..84d14fb595 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -217,3 +217,46 @@ def test_quantconv2d(backend, io_type): hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) + + +# QuantCat seems to be broken in brevitas itself, disable this test for the moment. 
+# class QuantModelConcacenate(Module): +# def __init__(self): +# super().__init__() +# self.cat = qnn.QuantCat( +# input_quant=Int8ActPerTensorFixedPoint, output_quant=Int8ActPerTensorFixedPoint +# ) + +# def forward(self, x, y, dim): +# out = self.cat([x,y], dim=dim) +# return out + + +# @pytest.mark.parametrize('backend', ['Vivado']) +# @pytest.mark.parametrize('io_type', ['io_parallel']) +# @pytest.mark.parametrize('dim', [1, 2]) +# def test_concatenate2d(dim, io_type, backend): +# input_shape = (10, 3) + + +# x = torch.randn(input_shape) +# y = torch.randn(input_shape) + +# model = QuantModelConcacenate() +# pytorch_prediction = model(x,y,dim).detach().numpy() +# config = config_from_pytorch_model(model, input_shape=[(None, input_shape[0], input_shape[1]), +# (None, input_shape[0], input_shape[1])]) +# output_dir = str(test_root_path / f'hls4mlprj_brevitas_cat_{backend}_{io_type}_{dim}') + +# hls_model = convert_from_pytorch_model( +# model, +# hls_config=config, +# output_dir=output_dir, +# backend=backend, +# io_type=io_type, +# ) +# hls_model.compile() + +# hls_prediction = np.reshape(hls_model.predict([x.detach().numpy(), y.detach().numpy()]), pytorch_prediction.shape) + +# np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) From 7bf57f1ff4ccee8154f8e3f1ec1f777b48d926cd Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 26 Feb 2025 12:18:25 -0500 Subject: [PATCH 38/47] remove QuantCat, it's deprecated, add support for QuantUpsample in io_parallel --- hls4ml/converters/pytorch/merge.py | 10 +- hls4ml/converters/pytorch/reshape.py | 19 ++- hls4ml/converters/pytorch_to_hls.py | 6 + hls4ml/model/optimizer/__init__.py | 1 + hls4ml/model/optimizer/passes/quant_opt.py | 7 +- .../passes/resize_remove_constants.py | 17 +++ test/pytest/test_brevitas_parsing.py | 119 +++++++++++++----- 7 files changed, 133 insertions(+), 46 deletions(-) diff --git a/hls4ml/converters/pytorch/merge.py b/hls4ml/converters/pytorch/merge.py index 7d42b82fa9..1f1e11dcb7 100644 --- a/hls4ml/converters/pytorch/merge.py +++ b/hls4ml/converters/pytorch/merge.py @@ -1,6 +1,6 @@ -from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, pytorch_handler +from hls4ml.converters.pytorch_to_hls import pytorch_handler -concat_layers = ['cat', 'concat', 'concatenate', 'QuantCat'] +concat_layers = ['cat', 'concat', 'concatenate'] @pytorch_handler(*concat_layers) @@ -25,12 +25,6 @@ def parse_concat_layer(operation, layer_name, input_names, input_shapes, node, c output_shape = input_shapes[0][:] output_shape[layer['axis']] += input_shapes[1][layer['axis']] - if "Quant" in layer_name: - if class_object.input_quant.is_quant_enabled: - layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) - if class_object.output_quant.is_quant_enabled: - layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) - return layer, output_shape diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index f7392ab8da..8bf7a6d5cd 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -120,10 +120,25 @@ def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, return layer, output_shape -@pytorch_handler('Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d') +@pytorch_handler( + 'Upsample', + 'UpsamplingNearest2d', + 'UpsamplingBilinear2d', + 'QuantUpsample', + 'QuantUpsamplingNearest2d', + 'QuantUpsamplingBilinear2d', +) def 
handle_upsample(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): - assert operation in ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d'] + assert operation in [ + 'Upsample', + 'UpsamplingNearest2d', + 'UpsamplingBilinear2d', + 'QuantUpsample', + 'QuantUpsamplingNearest2d', + 'QuantUpsamplingBilinear2d', + ] + layer = {} layer['name'] = layer_name layer['inputs'] = input_names diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index cbffd51f72..92fea38111 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -88,6 +88,7 @@ def addQuantizationParameters(layer, quant_object, quant_type, act=False): bit_width = int(quant_object.bit_width) signed = quant_object.signed scale = float(quant_object.scale) + print("scale: ", scale) zeropoint = float(quant_object.zero_point) if signed: narrow = True @@ -98,6 +99,7 @@ def addQuantizationParameters(layer, quant_object, quant_type, act=False): bit_width = int(quant_object.bit_width()) signed = quant_object.is_signed scale = float(quant_object.scale()) + print("scale: ", scale) zeropoint = float(quant_object.zero_point()) narrow = quant_object.is_narrow_range rounding_mode = quant_object.rounding_mode @@ -252,6 +254,10 @@ def parse_pytorch_model(config, verbose=True): if pytorch_class not in supported_layers: raise Exception(f'Unsupported layer {pytorch_class}') + if 'IOType' in config.keys(): + if "QuantUpsampl" in pytorch_class and config['IOType'] == 'io_stream': + raise Exception('Quant upsampling layers currently not supported with io_stream') + if layer_counter != 0: input_shapes = [output_shape] # In case there are multiple inputs diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index b0062c59c8..b1ea9cd75d 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -84,6 +84,7 @@ 'merge_linear_activation', # many of the above optimzers need to be done before this 'infer_precision_types', + 'adjust_resize_input_precision', ], requires=['parse_qonnx'], ) # TODO Maybe not all QKeras optmizers belong here? 
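Both the Brevitas parsers above and the quant_opt.py changes that follow only accept scales that are exact powers of two: np.frexp factors the scale as mantissa * 2**exp, so mantissa == 0.5 identifies a power of two, and the exponent then shifts the integer width of the resulting fixed-point type. The following is a minimal, self-contained sketch of that mapping, for reference only; the helper name is illustrative and not part of hls4ml or of this patch.

import numpy as np

def po2_fixed_params(bitwidth, scale):
    # scale == mantissa * 2**exp; a power-of-two scale has mantissa == 0.5
    mantissa, exp = np.frexp(scale)
    if mantissa != 0.5:
        raise ValueError('non-power-of-two scale: fall back to the QONNX flow')
    # same arithmetic as QuantToActivation below: integer bits of the equivalent fixed-point type
    integer = bitwidth + int(exp) - 1
    return bitwidth, integer

print(po2_fixed_params(8, 0.25))  # (8, 6): an 8-bit value with scale 2**-2 fits ap_fixed<8,6>
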
diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index ed33141edd..538d77f4f2 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -88,7 +88,6 @@ class QuantToActivation(OptimizerPass): def match(self, node): # only matches after the other inputs are already folded - is_match = ( isinstance(node, Quant) and len(node.inputs) == 1 @@ -105,8 +104,8 @@ def match(self, node): scale_unit_or_po2 = (scale == np.ones_like(scale)).all() if not scale_unit_or_po2 and _ALSO_MATCH_PO2: # This optimization only works if all scales are the same - if np.all(scale[0] == scale): - mantissa, _ = np.frexp(scale[0]) + if np.all(next(iter(scale.flat)) == scale): + mantissa, _ = np.frexp(next(iter(scale.flat))) scale_unit_or_po2 = mantissa == 0.5 is_match = scale_unit_or_po2 @@ -125,7 +124,7 @@ def transform(self, model, node): integer = bitwidth scale = node.get_attr('scale') if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): - _, exp = np.frexp(scale[0]) + _, exp = np.frexp(next(iter(scale.flat))) integer = bitwidth + exp - 1 precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) diff --git a/hls4ml/model/optimizer/passes/resize_remove_constants.py b/hls4ml/model/optimizer/passes/resize_remove_constants.py index 69039c60a2..3469811962 100644 --- a/hls4ml/model/optimizer/passes/resize_remove_constants.py +++ b/hls4ml/model/optimizer/passes/resize_remove_constants.py @@ -36,3 +36,20 @@ def transform(self, model, node): # Clean all the '' inputs node.inputs = list(filter(None, node.inputs)) return True + + +class AdjustResizeInputPrecision(OptimizerPass): + """ + This optimizer makes sure that the input data type of a Resize layer matches the output data type of the previous layer. + """ + + def match(self, node): + is_match = isinstance(node, Resize) and not ( + node.get_input_node().types['result_t'].precision == node.get_output_variable().type.precision + ) + return is_match + + def transform(self, model, node): + node.get_output_variable().type.precision = node.get_input_node().types['result_t'].precision + + return True diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 84d14fb595..b2160262e1 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -8,6 +8,7 @@ from torch import nn from torch.nn import Module +import hls4ml from hls4ml.converters import convert_from_pytorch_model from hls4ml.utils.config import config_from_pytorch_model @@ -219,44 +220,98 @@ def test_quantconv2d(backend, io_type): np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) -# QuantCat seems to be broken in brevitas itself, disable this test for the moment. 
-# class QuantModelConcacenate(Module): -# def __init__(self): -# super().__init__() -# self.cat = qnn.QuantCat( -# input_quant=Int8ActPerTensorFixedPoint, output_quant=Int8ActPerTensorFixedPoint -# ) +in_height = 6 +in_width = 8 +in_feat = 4 -# def forward(self, x, y, dim): -# out = self.cat([x,y], dim=dim) -# return out +size = 2 +atol = 5e-3 -# @pytest.mark.parametrize('backend', ['Vivado']) -# @pytest.mark.parametrize('io_type', ['io_parallel']) -# @pytest.mark.parametrize('dim', [1, 2]) -# def test_concatenate2d(dim, io_type, backend): -# input_shape = (10, 3) +@pytest.fixture(scope='module') +def data_1d(): + X = np.random.rand(100, in_feat, in_width) + return X -# x = torch.randn(input_shape) -# y = torch.randn(input_shape) +@pytest.fixture(scope='module') +def data_2d(): + X = np.random.rand(100, in_feat, in_height, in_width) + return X -# model = QuantModelConcacenate() -# pytorch_prediction = model(x,y,dim).detach().numpy() -# config = config_from_pytorch_model(model, input_shape=[(None, input_shape[0], input_shape[1]), -# (None, input_shape[0], input_shape[1])]) -# output_dir = str(test_root_path / f'hls4mlprj_brevitas_cat_{backend}_{io_type}_{dim}') -# hls_model = convert_from_pytorch_model( -# model, -# hls_config=config, -# output_dir=output_dir, -# backend=backend, -# io_type=io_type, -# ) -# hls_model.compile() +class QuantUpsample1DModel(nn.Module): + def __init__(self): + super().__init__() + self.identity = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True) + self.upsample = qnn.QuantUpsample(scale_factor=2) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.upsample(self.identity(x))) + + +class QuantUpsample2DModel(nn.Module): + def __init__(self): + super().__init__() + # this scale_factor tests proper output shape calculation with fractional scaling and parsing per-axis scales + self.identity = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True) + self.upsample = qnn.QuantUpsamplingNearest2d(scale_factor=(1, 2.4)) # Would also work with Upsample(mode='nearest') + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.upsample(self.identity(x))) + + +@pytest.mark.parametrize('io_type', ['io_parallel']) # Quant upsampling layers currently not supported in io_stream +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_pytorch_upsampling1d(data_1d, io_type, backend): + model = QuantUpsample1DModel() + + config = hls4ml.utils.config_from_pytorch_model( + model, + (None, in_feat, in_width), + default_precision='ap_fixed<16,6>', + channels_last_conversion="internal", + transpose_outputs=False, + ) + odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + data_1d_t = np.ascontiguousarray(data_1d.transpose([0, 2, 1])) + + pytorch_prediction = model(torch.Tensor(data_1d)).value.detach().numpy() + hls_prediction = hls_model.predict(data_1d_t) + + pred_shape = list(pytorch_prediction.shape) + pred_shape.append(pred_shape.pop(1)) # Transpose shape to channels_last + hls_prediction = hls_prediction.reshape(pred_shape).transpose([0, 2, 1]) # Transpose back + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + -# hls_prediction = np.reshape(hls_model.predict([x.detach().numpy(), y.detach().numpy()]), pytorch_prediction.shape) 
+@pytest.mark.parametrize('io_type', ['io_parallel']) # Fractional scaling doesn't work with io_stream +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_pytorch_upsampling2d(data_2d, io_type, backend): + model = QuantUpsample2DModel() -# np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) + config = hls4ml.utils.config_from_pytorch_model( + model, + (in_feat, in_height, in_width), + default_precision='ap_fixed<16,6>', + channels_last_conversion="full", # With conversion to channels_last + transpose_outputs=True, + ) + odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + pytorch_prediction = model(torch.Tensor(data_2d)).value.detach().numpy().flatten() + hls_prediction = hls_model.predict(data_2d).flatten() + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) From c5f96c67825de918080bfcf5a6d846fb279825f9 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Thu, 27 Feb 2025 12:30:18 -0500 Subject: [PATCH 39/47] add support for QuantEltWiseAdd --- hls4ml/converters/pytorch/core.py | 2 +- hls4ml/converters/pytorch/merge.py | 10 +++- hls4ml/converters/pytorch_to_hls.py | 7 +-- .../optimizer/passes/brevitas_optimizer.py | 3 +- .../passes/convert_to_channels_last.py | 2 +- test/pytest/test_brevitas_parsing.py | 48 +++++++++++++++++-- 6 files changed, 58 insertions(+), 14 deletions(-) diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index 98b56731fa..e4d73263b6 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -32,7 +32,7 @@ def parse_quantidentity_layer(operation, layer_name, input_names, input_shapes, layer['class_name'] = 'Quant' layer['name'] = layer_name - print(input_shapes) + if class_object.act_quant.is_quant_enabled: layer['bitwidth'] = int(class_object.act_quant.bit_width()) layer['signed'] = class_object.act_quant.is_signed diff --git a/hls4ml/converters/pytorch/merge.py b/hls4ml/converters/pytorch/merge.py index 1f1e11dcb7..faae193ac9 100644 --- a/hls4ml/converters/pytorch/merge.py +++ b/hls4ml/converters/pytorch/merge.py @@ -1,4 +1,4 @@ -from hls4ml.converters.pytorch_to_hls import pytorch_handler +from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, pytorch_handler concat_layers = ['cat', 'concat', 'concatenate'] @@ -28,7 +28,7 @@ def parse_concat_layer(operation, layer_name, input_names, input_shapes, node, c return layer, output_shape -add_layers = ['add'] +add_layers = ['add', 'QuantEltwiseAdd'] multiply_layers = ['mul', 'multiply'] subtract_layers = ['sub', 'subtract'] min_layers = ['fmin', 'minimum'] @@ -56,6 +56,12 @@ def parse_merge_layer(operation, layer_name, input_names, input_shapes, node, cl layer['inputs'] = input_names + if 'Quant' in operation: + if class_object.input_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True) + if class_object.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True, scale_up=True) + output_shape = input_shapes[0][:] return layer, output_shape diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index cf2eda677f..9cf5b379dd 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ 
b/hls4ml/converters/pytorch_to_hls.py @@ -82,13 +82,12 @@ def convert_uaq_to_apfixed(bitwidth, scale_factor): # embed quantization information into the layer dictionary for a Quant layer # so that this layer can be added to the model -def addQuantizationParameters(layer, quant_object, quant_type, act=False): +def addQuantizationParameters(layer, quant_object, quant_type, act=False, scale_up=False): if not act: # currently not used, might be use later for non-power-of-2 scales bit_width = int(quant_object.bit_width) signed = quant_object.signed scale = float(quant_object.scale) - print("scale: ", scale) zeropoint = float(quant_object.zero_point) if signed: narrow = True @@ -99,7 +98,9 @@ def addQuantizationParameters(layer, quant_object, quant_type, act=False): bit_width = int(quant_object.bit_width()) signed = quant_object.is_signed scale = float(quant_object.scale()) - print("scale: ", scale) + # bit of a hack to make adding operations with QuantEltWiseAdd work + if scale_up: + scale = 2 ** (math.log2(scale) + 1) zeropoint = float(quant_object.zero_point()) narrow = quant_object.is_narrow_range rounding_mode = quant_object.rounding_mode diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index afa3d345d2..33ba4985ad 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -51,13 +51,12 @@ def transform(self, model, node): attributes['scale'] = np.array([node.attributes['input_quantization']['scale']]) for i, input in enumerate(node.inputs): - quant_node = model.make_node( 'Quant', f'quant_input_for_{node.get_attr("name")}_input_{i}', attributes, [input] ) quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}') - model.insert_node(quant_node) + model.insert_node(quant_node) node.attributes['input_quantization'] = {} diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index 6511a6967b..24abf34e48 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -116,7 +116,7 @@ def transform(self, model, node): # Add transpose for output layer elif ( - node.get_attr('name') in model.outputs + node.name in model.outputs and len(outshape) > 1 and model.config.config['HLSConfig']['Model']['TransposeOutputs'] ): diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index b2160262e1..293620e01c 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -46,13 +46,11 @@ def forward(self, x): class QuantModelLinear(Module): def __init__(self, weight_quant, input_quant): super().__init__() - self.conv1 = qnn.QuantLinear( - 4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint, input_quant=Int8ActPerTensorFixedPoint - ) - self.relu1 = qnn.QuantReLU(act_quant=Int8ActPerTensorFixedPoint) + self.lin1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=quants[weight_quant], input_quant=quants[input_quant]) + self.relu1 = qnn.QuantReLU(act_quant=quants[input_quant]) def forward(self, x): - out = self.relu1(self.conv1(x)) + out = self.relu1(self.lin1(x)) return out @@ -315,3 +313,43 @@ def test_pytorch_upsampling2d(data_2d, io_type, backend): hls_prediction = hls_model.predict(data_2d).flatten() np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + + +class QuantEltwiseAddModel(nn.Module): + def 
__init__(self): + super().__init__() + self.add = qnn.QuantEltwiseAdd(input_quant=Int8ActPerTensorFixedPoint, output_quant=Int8ActPerTensorFixedPoint) + + def forward(self, x, y): + return self.add(x, y) + + +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_brevitas_quanteltwiseadd(io_type, backend): + model = QuantEltwiseAddModel() + + x = torch.rand(1, 4, 4) + y = torch.rand(1, 4, 4) + + pytorch_prediction = model(torch.Tensor(x), torch.Tensor(y)).detach().numpy() + + config = hls4ml.utils.config_from_pytorch_model( + model, + [(None, 4, 4), (None, 4, 4)], + default_precision='ap_fixed<16,6>', + channels_last_conversion="off", + transpose_outputs=False, + ) + odir = str(test_root_path / f'hls4mlprj_brevitas_quanteltwiseadd_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + hls_prediction = hls_model.predict([x.detach().numpy(), y.detach().numpy()]) + + pred_shape = pytorch_prediction.shape + hls_prediction = hls_prediction.reshape(pred_shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=5e-2, atol=0.05) From 05d99fac485927cd1c89bb1930b641b6081c3d1b Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Thu, 6 Mar 2025 11:03:43 -0500 Subject: [PATCH 40/47] first try at RNN layers --- .../vivado/passes/recurrent_templates.py | 4 +- hls4ml/converters/pytorch/pooling.py | 2 - hls4ml/converters/pytorch/recurrent.py | 317 +++++++++++++++++- hls4ml/converters/pytorch_to_hls.py | 2 +- hls4ml/model/graph.py | 4 +- hls4ml/model/layers.py | 29 +- .../optimizer/passes/brevitas_optimizer.py | 4 +- .../firmware/nnet_utils/nnet_recurrent.h | 2 +- .../vivado/nnet_utils/nnet_recurrent.h | 7 +- hls4ml/writer/vivado_writer.py | 3 +- 10 files changed, 352 insertions(+), 22 deletions(-) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index a4bd649efa..b589188d45 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -47,7 +47,7 @@ static const unsigned table_size = {table_size}; static const unsigned io_type = nnet::{iotype}; static const unsigned reuse_factor = {reuse}; - typedef {table_t.name} table_t; + typedef {act_t.name} table_t; }};\n""" recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{ @@ -55,7 +55,7 @@ static const unsigned table_size = {table_size}; static const unsigned io_type = nnet::{iotype}; static const unsigned reuse_factor = {reuse}; - typedef {table_t.name} table_t; + typedef {recurr_act_t.name} table_t; }};\n""" # LSTM + GRU templates diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index 882a7f7aed..b43699a72a 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -6,8 +6,6 @@ 'MaxPool2d', 'AvgPool1d', 'AvgPool2d', - 'QuantMaxPool1d', - 'QuantMaxPool2d', ] # TODO add support for special quantized average pool layers diff --git a/hls4ml/converters/pytorch/recurrent.py b/hls4ml/converters/pytorch/recurrent.py index 5d8f6a58bd..b1d810dd57 100644 --- a/hls4ml/converters/pytorch/recurrent.py +++ b/hls4ml/converters/pytorch/recurrent.py @@ -1,6 +1,8 @@ import numpy as np -from hls4ml.converters.pytorch_to_hls import pytorch_handler +from hls4ml.converters.pytorch_to_hls 
import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler +from hls4ml.model.quantizers import BrevitasQuantizer +from hls4ml.model.types import FixedPrecisionType, NamedType rnn_layers = ['RNN', 'LSTM', 'GRU'] @@ -72,3 +74,316 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas layer['pass_initial_states'] = True return layer, output_shape + + +quant_rnn_layers = ['QuantRNN', 'QuantLSTM'] # No QuantGRU in brevitas at this point + + +@pytorch_handler(*quant_rnn_layers) +def parse_quant_rnn_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): + assert operation in quant_rnn_layers + operation = operation.split('Quant')[-1] + + if len(class_object._modules['layers']) > 1: + raise Exception('hls4ml does not support num_layers > 1') + + if class_object.num_directions > 1: + raise Exception('hls4ml does not support birectional RNNs') + + layer = {} + + layer["name"] = layer_name + + layer['inputs'] = input_names + if 'IOType' in config.keys(): + if len(input_names) > 1 and config['IOType'] == 'io_stream': + raise Exception('Passing initial values for the hidden state is not supported for io_stream input type.') + + layer['class_name'] = operation + if operation == 'RNN': + layer['class_name'] = 'SimpleRNN' + + layer['return_sequences'] = False # parameter does not exist in pytorch + layer['return_state'] = False # parameter does not exist in pytorch + + if layer['class_name'] == 'SimpleRNN': + layer['activation'] = 'tanh' if 'Tanh' in str(class_object._modules['layers'][0][0].cell.act_fn) else 'ReLU' + else: + layer['activation'] = 'tanh' # GRU and LSTM are hard-coded to use tanh in pytorch + + if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM': + layer['recurrent_activation'] = 'sigmoid' # GRU and LSTM are hard-coded to use sigmoid in pytorch + + layer['time_major'] = not class_object._modules['layers'][0][0].cell.batch_first + # TODO Should we handle time_major? 
+ if layer['time_major']: + raise Exception('hls4ml only supports "batch-first == True"') + + layer['n_timesteps'] = input_shapes[0][1] + layer['n_in'] = input_shapes[0][2] + + layer['n_out'] = class_object._modules['layers'][0][0].hidden_size + + if 'LSTM' in operation: + LSTMObject = class_object._modules['layers'][0][0] + + input_weight = LSTMObject.input_gate_params.input_weight + forget_weight = LSTMObject.forget_gate_params.input_weight + cell_weight = LSTMObject.cell_gate_params.input_weight + output_weight = LSTMObject.output_gate_params.input_weight + + input_hidden_weight = LSTMObject.input_gate_params.hidden_weight + forget_hidden_weight = LSTMObject.forget_gate_params.hidden_weight + cell_hidden_weight = LSTMObject.cell_gate_params.hidden_weight + output_hidden_weight = LSTMObject.output_gate_params.hidden_weight + + width = int(input_weight.quant_weight().bit_width) + scale = input_weight.quant_weight().scale.detach().numpy() + + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, scale) + combined_weight = np.concatenate( + ( + input_weight.quant_weight().detach().value.numpy(), + forget_weight.quant_weight().detach().value.numpy(), + cell_weight.quant_weight().detach().value.numpy(), + output_weight.quant_weight().detach().value.numpy(), + ), + axis=0, + ) + layer['weight_data'] = combined_weight + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + width = int(input_hidden_weight.quant_weight().bit_width) + scale = input_hidden_weight.quant_weight().scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, scale) + + combined_hidden_weight = np.concatenate( + ( + input_hidden_weight.quant_weight().detach().value.numpy(), + forget_hidden_weight.quant_weight().detach().value.numpy(), + cell_hidden_weight.quant_weight().detach().value.numpy(), + output_hidden_weight.quant_weight().detach().value.numpy(), + ), + axis=0, + ) + + layer['recurrent_weight_data'] = combined_hidden_weight + layer['recurrent_weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + + input_bias = LSTMObject.input_gate_params.quant_bias() + forget_bias = LSTMObject.forget_gate_params.quant_bias() + cell_bias = LSTMObject.cell_gate_params.quant_bias() + output_bias = LSTMObject.output_gate_params.quant_bias() + + if input_bias is not None: + width = int(input_bias.bit_width) + scale = input_bias.scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, scale) + + combined_hidden_weight = np.concatenate( + ( + input_bias.detach().value.numpy(), + forget_bias.detach().value.numpy(), + cell_bias.detach().value.numpy(), + output_bias.detach().value.numpy(), + ), + axis=0, + ) + + layer['bias_data'] = combined_hidden_weight + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + else: + layer['bias_data'] = np.zeros(layer['weight_data'].shape[0]) + layer['bias_quantizer'] = layer['weight_quantizer'] + + layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0]) + layer['recurrent_bias_quantizer'] = layer['bias_quantizer'] + + acc_scale = LSTMObject.cell.forget_acc_quant.scale() + acc_bitwdith = int(LSTMObject.cell.forget_acc_quant.bit_width()) + mantissa, _ = np.frexp(acc_scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(acc_bitwdith, acc_scale) + precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision) + + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + tanh_scale = LSTMObject.cell.cell_tanh_quant.scale() + tanh_bitwdith = int(LSTMObject.cell.cell_tanh_quant.bit_width()) + mantissa, _ = np.frexp(tanh_scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(tanh_bitwdith, tanh_scale) + precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + layer['act_t'] = NamedType(layer["name"] + '_act_t', precision) + + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + sigmoid_scale = LSTMObject.cell.cell_tanh_quant.scale() + sigmoid_bitwdith = int(LSTMObject.cell.cell_tanh_quant.bit_width()) + mantissa, _ = np.frexp(sigmoid_scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(sigmoid_bitwdith, sigmoid_scale) + precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + layer['recurr_act_t'] = NamedType(layer["name"] + '_recurr_act_t', precision) + + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + if LSTMObject.cell.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, LSTMObject.cell.output_quant, 'output', act=True) + layer = addQuantizationParameters(layer, LSTMObject.cell.output_quant, 'input', act=True) + + else: + + RNNObject = class_object._modules['layers'][0][0] + + if RNNObject.gate_params.input_weight.weight_quant.is_quant_enabled: + width = int(RNNObject.gate_params.input_weight.quant_weight().bit_width) + scale = RNNObject.gate_params.input_weight.quant_weight().scale.detach().numpy() + signed = RNNObject.gate_params.input_weight.quant_weight().signed + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed( + width, float(RNNObject.gate_params.input_weight.quant_weight().scale) + ) + layer['weight_data'] = RNNObject.gate_params.input_weight.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + if RNNObject.gate_params.hidden_weight.weight_quant.is_quant_enabled: + width = int(RNNObject.gate_params.hidden_weight.quant_weight().bit_width) + scale = RNNObject.gate_params.hidden_weight.quant_weight().scale.detach().numpy() + signed = RNNObject.gate_params.input_weight.quant_weight().signed + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed( + width, float(RNNObject.gate_params.hidden_weight.quant_weight().scale) + ) + layer['recurrent_weight_data'] = RNNObject.gate_params.hidden_weight.quant_weight().detach().value.numpy() + layer['recurrent_weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + input_bias = RNNObject.gate_params.quant_bias() + if input_bias is not None: + width = int(input_bias.bit_width) + scale = input_bias.scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, scale) + + layer['bias_data'] = input_bias.detach().value.numpy() + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+ Please used QONNX instead.''' + ) + else: + layer['bias_data'] = np.zeros(layer['weight_data'].shape[0]) + layer['bias_quantizer'] = layer['weight_quantizer'] + + layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0]) + layer['recurrent_bias_quantizer'] = layer['weight_quantizer'] + + acc_scale = RNNObject.cell.gate_acc_quant.scale() + acc_bitwdith = int(RNNObject.cell.gate_acc_quant.bit_width()) + mantissa, _ = np.frexp(acc_scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(acc_bitwdith, acc_scale) + precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision) + + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + if RNNObject.cell.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'output', act=True) + layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'input', act=True) + + if layer['class_name'] == 'GRU': + layer['apply_reset_gate'] = 'after' # Might be true for pytorch? It's not a free parameter + + output_shape = [input_shapes[0][0], layer['n_out']] + + layer['pytorch'] = True # need to switch some behaviors to match pytorch implementations + if len(input_names) == 1: + layer['pass_initial_states'] = False + else: + layer['pass_initial_states'] = True + + return layer, output_shape diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 9cf5b379dd..99edfd9f2e 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -283,7 +283,7 @@ def parse_pytorch_model(config, verbose=True): # parse info from class object input_names = [inputs_map.get(str(i), str(i)) for i in node.args] - if pytorch_class in ["RNN", "GRU", "LSTM"]: + if pytorch_class in ['RNN', 'GRU', 'LSTM', 'QuantRNN', 'QuantLSTM']: input_shapes = [] input_names = [] for arg in node.args: diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index cf3bb4362e..6a940caff4 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -494,7 +494,9 @@ def insert_node(self, node, before=None, input_idx=0): next_nodes.append(x) if before is None: - next_node = next((x for x in self.graph.values() if x.inputs and x.inputs[0] in prev_node.outputs), None) + next_node = next( + (x for x in self.graph.values() if x.inputs and set(x.inputs).intersection(prev_node.outputs)), None + ) else: if before not in next_nodes: raise Exception( diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index cb32ab439f..ab2cc5bdfc 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1293,6 +1293,7 @@ class SimpleRNN(Layer): TypeAttribute('weight'), TypeAttribute('bias'), TypeAttribute('recurrent_weight'), + TypeAttribute('accum_t'), ] def initialize(self): @@ -1316,15 +1317,19 @@ def initialize(self): ) # weights - self.add_weights() + self.add_weights(quantizer=self.get_attr('weight_quantizer')) # recurrent weights - self.add_weights_variable(name='recurrent_weight', var_name='wr{index}') + self.add_weights_variable( + name='recurrent_weight', var_name='wr{index}', quantizer=self.get_attr('recurrent_weight_quantizer') + ) # biases - self.add_weights_variable(name='bias', var_name='b{index}') + 
self.add_weights_variable(name='bias', var_name='b{index}', quantizer=self.get_attr('bias_quantizer')) if "pytorch" in self.attributes.keys(): - self.add_weights_variable(name='recurrent_bias', var_name='br{index}') + self.add_weights_variable( + name='recurrent_bias', var_name='br{index}', quantizer=self.get_attr('recurrent_bias_quantizer') + ) class LSTM(Layer): @@ -1345,6 +1350,7 @@ class LSTM(Layer): TypeAttribute('bias'), TypeAttribute('recurrent_weight'), TypeAttribute('recurrent_bias'), + TypeAttribute('accum_t'), ] def initialize(self): @@ -1368,17 +1374,24 @@ def initialize(self): ) # weights - self.add_weights() + self.add_weights(quantizer=self.get_attr('weight_quantizer')) # recurrent weights recurrent_weight = self.get_attr('recurrent_weight_data') - self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight) + self.add_weights_variable( + name='recurrent_weight', + var_name='wr{index}', + data=recurrent_weight, + quantizer=self.get_attr('recurrent_weight_quantizer'), + ) # biases - self.add_weights_variable(name='bias', var_name='b{index}') + self.add_weights_variable(name='bias', var_name='b{index}', quantizer=self.get_attr('bias_quantizer')) if "pytorch" in self.attributes.keys(): - self.add_weights_variable(name='recurrent_bias', var_name='br{index}') + self.add_weights_variable( + name='recurrent_bias', var_name='br{index}', quantizer=self.get_attr('recurrent_bias_quantizer') + ) else: recurrent_bias = np.zeros(recurrent_weight.shape[1]) self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias) diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index 33ba4985ad..294e00a20d 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -54,9 +54,9 @@ def transform(self, model, node): quant_node = model.make_node( 'Quant', f'quant_input_for_{node.get_attr("name")}_input_{i}', attributes, [input] ) - quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}') + quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}_input_{i}') - model.insert_node(quant_node) + model.insert_node(quant_node, input_idx=i) node.attributes['input_quantization'] = {} diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h index d3411f351b..794e46972e 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h @@ -501,7 +501,7 @@ void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], h_T } // Do SimpleRNN - simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); + simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); // Write result #pragma unroll diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h index 618767dcb5..8d36a6e519 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h @@ -73,8 +73,9 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG #pragma HLS ARRAY_PARTITION variable=inputacc_c complete #pragma HLS ARRAY_PARTITION variable=s_actstate complete - nnet::dense(data, tmpres, param, param_b); - nnet::dense(h_newstate, tmpres_state, param_r, param_br); + 
nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_newstate, tmpres_state, param_r, + param_br); for (int iacc = 0; iacc < (3 * CONFIG_T::n_state); iacc++) { #pragma HLS UNROLL @@ -254,7 +255,7 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], h_T h_newsta data_in[j] = data[j + iloop * CONFIG_T::n_in]; } - nnet::lstm(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); + nnet::lstm(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); if (CONFIG_T::n_sequence_out > 1) for (int i = CONFIG_T::n_state * iloop, j = 0; i < (CONFIG_T::n_state * (iloop + 1)); i++, j++) { #pragma HLS UNROLL diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index 7c99e15b95..c0907acfaf 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -119,7 +119,8 @@ def write_project_cpp(self, model): Args: model (ModelGraph): the hls4ml model. """ - + for name in model.graph: + print(name) filedir = os.path.dirname(os.path.abspath(__file__)) f = open(os.path.join(filedir, '../templates/vivado/firmware/myproject.cpp')) From 7ad2709b5e067ec0b4cd781518f4ae3eda994547 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Thu, 6 Mar 2025 11:07:34 -0500 Subject: [PATCH 41/47] remove leftover print statement --- hls4ml/writer/vivado_writer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hls4ml/writer/vivado_writer.py b/hls4ml/writer/vivado_writer.py index c0907acfaf..7c99e15b95 100644 --- a/hls4ml/writer/vivado_writer.py +++ b/hls4ml/writer/vivado_writer.py @@ -119,8 +119,7 @@ def write_project_cpp(self, model): Args: model (ModelGraph): the hls4ml model. """ - for name in model.graph: - print(name) + filedir = os.path.dirname(os.path.abspath(__file__)) f = open(os.path.join(filedir, '../templates/vivado/firmware/myproject.cpp')) From c3912bcf2b7f92e2b836ebad07f911f3744e6965 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 21 Mar 2025 14:37:13 -0400 Subject: [PATCH 42/47] add tests for QuantRNN and QuantLSTM layers. 
Accuracy still quite bad --- hls4ml/converters/pytorch/core.py | 7 - .../firmware/nnet_utils/nnet_recurrent.h | 176 +++++++++++++++- .../vivado/nnet_utils/nnet_recurrent.h | 4 +- test/pytest/test_recurrent_brevitas.py | 191 ++++++++++++++++++ 4 files changed, 366 insertions(+), 12 deletions(-) create mode 100644 test/pytest/test_recurrent_brevitas.py diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py index e4d73263b6..958a4aae55 100644 --- a/hls4ml/converters/pytorch/core.py +++ b/hls4ml/converters/pytorch/core.py @@ -102,13 +102,6 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c if class_object.output_quant.is_quant_enabled: layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True) - else: - layer['weight_data'] = class_object.weight.data.numpy() - if class_object.bias is not None: - layer['bias_data'] = class_object.bias.data.numpy() - else: - layer['bias_data'] = None - if class_object is not None: layer['n_in'] = class_object.in_features layer['n_out'] = class_object.out_features diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h index 794e46972e..53b4d79618 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h @@ -14,7 +14,7 @@ namespace nnet { template void multiply_W(data_T input[N_IN], res_T out[N_OUT], const weight_t weight[N_IN * N_OUT]) { MULTIPLY_W_LOOP_I: - #pragma unroll + #pragma unroll, for (int i = 0; i < N_OUT; i++) { out[i] = 0; @@ -58,6 +58,14 @@ template void multiply_vectors(data_T in1[N], } } +template void multiply_vectors(data_T in1[N], s_T in2[N], res_T out[N]) { +MULTIPLY_VECT_LOOP: + #pragma unroll + for (int i = 0; i < N; i++) { + out[i] = in1[i] * in2[i]; + } +} + template void add_vectors(data_T in1[N], data_T in2[N], res_T out[N]) { ADD_VECTOR_LOOP: #pragma unroll @@ -718,6 +726,168 @@ void lstm_cell(data_T inputs[CONFIG_T::n_in], res_T hidden_state[CONFIG_T::n_out } } +template +void lstm_cell(data_T inputs[CONFIG_T::n_in], h_T hidden_state[CONFIG_T::n_out], h_T hidden_state_o[CONFIG_T::n_out], + s_T cell_state[CONFIG_T::n_out], s_T cell_state_o[CONFIG_T::n_out], + const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out], + const typename CONFIG_T::weight_t WF[CONFIG_T::n_in * CONFIG_T::n_out], + const typename CONFIG_T::weight_t WC[CONFIG_T::n_in * CONFIG_T::n_out], + const typename CONFIG_T::weight_t WO[CONFIG_T::n_in * CONFIG_T::n_out], + const typename CONFIG_T::weight_t RWI[CONFIG_T::n_out * CONFIG_T::n_out], + const typename CONFIG_T::weight_t RWF[CONFIG_T::n_out * CONFIG_T::n_out], + const typename CONFIG_T::weight_t RWC[CONFIG_T::n_out * CONFIG_T::n_out], + const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out], + const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out], + const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) { + + // Internals definitions + typename CONFIG_T::accum_t i_afterW[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t i_afterBias[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t c_afterW[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t c_afterBias[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t o_afterW[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t 
o_afterBias[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t f_afterW[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t f_afterBias[CONFIG_T::n_out] hls_register; + + // Hidden state Gate candidates, intermediate variables + typename CONFIG_T::accum_t i_hiddenCand[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t f_hiddenCand[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t c_hiddenCand[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t o_hiddenCand[CONFIG_T::n_out] hls_register; + + // After addition, intermediate variables + typename CONFIG_T::accum_t i_afterAdd[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t f_afterAdd[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t c_afterAdd[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t o_afterAdd[CONFIG_T::n_out] hls_register; + + // Gate outputs + typename CONFIG_T::accum_t gate_i[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t gate_f[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t gate_c[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t gate_o[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t gate_ic[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t gate_forget[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t h[CONFIG_T::n_out] hls_register; + + // Intermediate variable cell calculation + typename CONFIG_T::accum_t cell_act_multp[CONFIG_T::n_out] hls_register; + typename CONFIG_T::accum_t cell_act_add[CONFIG_T::n_out] hls_register; + + //-----------Gate I Calculations + // Weight multiplication + multiply_W( + inputs, i_afterW, WI); + + // Bias addition + add_bias( + i_afterW, i_afterBias, BI); + + // Hidden Candidate + multiply_U(hidden_state, i_hiddenCand, + RWI); + + // Vector addition + add_vectors(i_afterBias, i_hiddenCand, + i_afterAdd); + + // Activation + CONFIG_T::template activation_recr::activation(i_afterAdd, gate_i); + + //-----------Gate F Calculations + // Weight multiplication + multiply_W( + inputs, f_afterW, WF); + + // Bias addition + add_bias( + f_afterW, f_afterBias, BF); + + // Hidden Candidate + multiply_U(hidden_state, f_hiddenCand, + RWF); + + // Vector addition + add_vectors(f_afterBias, f_hiddenCand, + f_afterAdd); + + // Activation + CONFIG_T::template activation_recr::activation(f_afterAdd, gate_f); + + //-----------Gate C Calculations + // Weight multiplication + multiply_W( + inputs, c_afterW, WC); + + // Bias addition + add_bias( + c_afterW, c_afterBias, BC); + + // Hidden Candidate + multiply_U(hidden_state, c_hiddenCand, + RWC); + + // Vector addition + add_vectors(c_afterBias, c_hiddenCand, + c_afterAdd); + + // Activation + CONFIG_T::template activation::activation(c_afterAdd, gate_c); + + //-----------gate I and C multiply + // Vector multiplication + multiply_vectors(gate_i, gate_c, gate_ic); + + //-----------Gate O Calculations + // Weight multiplication + multiply_W( + inputs, o_afterW, WO); + + // Bias addition + add_bias( + o_afterW, o_afterBias, BO); + + // Hidden Candidate + multiply_U(hidden_state, o_hiddenCand, + RWO); + + // Vector addition + add_vectors(o_afterBias, o_hiddenCand, + o_afterAdd); + + // Activation + CONFIG_T::template activation_recr::activation(o_afterAdd, gate_o); + + //-----------Cell State Calculation + // Vector multiplication + multiply_vectors(gate_f, cell_state, + cell_act_multp); + + // Vector addition + add_vectors(gate_ic, cell_act_multp, + cell_act_add); + + //-----------Forget gate Calculation + // Activation + 
CONFIG_T::template activation::activation(cell_act_add, gate_forget); + + // Vector multiplication + multiply_vectors(gate_o, gate_forget, h); + +OUTPUT_WRITE_LOOP: + #pragma unroll + for (int x = (CONFIG_T::n_out - 1); x >= 0; x--) { + hidden_state_o[x] = h[x]; + cell_state_o[x] = cell_act_add[x]; + } +} + template void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out], const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out], @@ -835,8 +1005,8 @@ void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], h_T hidden_state_ } // Do LSTM - lstm_cell(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO, - BI, BF, BC, BO); + lstm_cell(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, + RWC, RWO, BI, BF, BC, BO); // Write result #pragma unroll diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h index 8d36a6e519..5a2783c2c1 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h @@ -74,8 +74,8 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG #pragma HLS ARRAY_PARTITION variable=s_actstate complete nnet::dense(data, tmpres, param, param_b); - nnet::dense(h_newstate, tmpres_state, param_r, - param_br); + nnet::dense(h_newstate, tmpres_state, param_r, + param_br); for (int iacc = 0; iacc < (3 * CONFIG_T::n_state); iacc++) { #pragma HLS UNROLL diff --git a/test/pytest/test_recurrent_brevitas.py b/test/pytest/test_recurrent_brevitas.py new file mode 100644 index 0000000000..78de67cc49 --- /dev/null +++ b/test/pytest/test_recurrent_brevitas.py @@ -0,0 +1,191 @@ +from pathlib import Path + +import brevitas.nn as qnn +import numpy as np +import pytest +import torch +from brevitas.quant import ( + Int8ActPerTensorFixedPoint, + Int8BiasPerTensorFixedPointInternalScaling, + Int8WeightPerTensorFixedPoint, +) +from torch import nn + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +test_root_path = Path(__file__).parent + + +class QuantRNNModel(nn.Module): + def __init__(self): + super().__init__() + self.rnn = qnn.QuantRNN( + input_size=10, + hidden_size=20, + bidirectional=False, + shared_input_hidden_weights=False, + batch_first=True, + weight_quant=Int8WeightPerTensorFixedPoint, + bias_quant=Int8BiasPerTensorFixedPointInternalScaling, + io_quant=Int8ActPerTensorFixedPoint, + gate_acc_quant=Int8ActPerTensorFixedPoint, + return_quant_tensor=True, + bias=True, + ) + + def forward(self, x, h0): + output, _ = self.rnn(x, (h0)) + return output + + +@pytest.mark.parametrize('backend', ['Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +def test_rnn(backend, io_type): + model = QuantRNNModel() + model.eval() + + X_input = torch.randn(1, 1, 10) + X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> + h0 = torch.randn(1, 1, 20) + h0 = np.round(h0 * 2**16) * 2**-16 + + pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().value.numpy() + + config = config_from_pytorch_model( + model, + [(None, 1, 10), (None, 1, 20)], + channels_last_conversion="off", + transpose_outputs=False, + default_precision='fixed<32,16>', + ) + output_dir = str(test_root_path / f'hls4mlprj_brevitas_rnn_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model(model, hls_config=config, 
output_dir=output_dir, backend=backend, io_type=io_type) + + hls_model.compile() + + hls_prediction = np.reshape(hls_model.predict([X_input.detach().numpy(), h0.detach().numpy()]), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far + + +class QuantLSTMModel(nn.Module): + def __init__(self): + super().__init__() + self.rnn = qnn.QuantLSTM( + 10, + 20, + num_layers=1, + batch_first=True, + bias=False, + weight_quant=Int8WeightPerTensorFixedPoint, + bias_quant=Int8BiasPerTensorFixedPointInternalScaling, + gate_acc_quant=Int8ActPerTensorFixedPoint, + return_quant_tensor=True, + sigmoid_quant=Int8ActPerTensorFixedPoint, + tanh_quant=Int8ActPerTensorFixedPoint, + shared_intra_layer_weight_quant=True, + shared_intra_layer_gate_acc_quant=True, + ) + + def forward(self, x, h0, c0): + output, (_, _) = self.rnn(x, (h0, c0)) + return output + + +class QuantLSTMModelStream(nn.Module): + def __init__(self): + super().__init__() + self.rnn = qnn.QuantLSTM( + 10, + 20, + num_layers=1, + batch_first=True, + bias=False, + weight_quant=Int8WeightPerTensorFixedPoint, + bias_quant=Int8BiasPerTensorFixedPointInternalScaling, + gate_acc_quant=Int8ActPerTensorFixedPoint, + return_quant_tensor=True, + sigmoid_quant=Int8ActPerTensorFixedPoint, + tanh_quant=Int8ActPerTensorFixedPoint, + shared_intra_layer_weight_quant=True, + shared_intra_layer_gate_acc_quant=True, + ) + + def forward(self, x): + output, (_, _) = self.rnn(x) + return output + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +def test_lstm(backend, io_type): + model = QuantLSTMModel() + model.eval() + + X_input = torch.randn(1, 1, 10) + X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> + h0 = torch.randn(1, 1, 20) + h0 = np.round(h0 * 2**16) * 2**-16 + c0 = torch.randn(1, 1, 20) + c0 = np.round(c0 * 2**16) * 2**-16 + + pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0), torch.tensor(c0)).detach().value.numpy() + + config = config_from_pytorch_model( + model, + [(None, 1, 10), (None, 1, 20), (None, 1, 20)], + channels_last_conversion="off", + transpose_outputs=False, + default_precision='fixed<32,16>', + ) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}') + + hls_model = convert_from_pytorch_model( + model, + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + + hls_model.compile() + + hls_prediction = np.reshape( + hls_model.predict([X_input.detach().numpy(), h0.detach().numpy(), c0.detach().numpy()]), pytorch_prediction.shape + ) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_stream']) +def test_lstm_stream(backend, io_type): + if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"): + model = QuantLSTMModelStream() + model.eval() + + X_input = torch.randn(1, 1, 10) + X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> + + pytorch_prediction = model(torch.Tensor(X_input)).detach().value.numpy() + + config = config_from_pytorch_model( + model, [(None, 1, 10)], channels_last_conversion="off", transpose_outputs=False, default_precision='fixed<32,16>' + ) + output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}') + + hls_model = 
convert_from_pytorch_model( + model, + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + + hls_model.compile() + + hls_prediction = np.reshape(hls_model.predict(X_input.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far From 111b83ba1870a0680a02497666e60d1c6af46785 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 26 Mar 2025 14:08:27 -0400 Subject: [PATCH 43/47] don't break RNNs in non-brevitas workflows --- .../vivado/passes/recurrent_templates.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index b589188d45..e25f50651c 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -47,10 +47,26 @@ static const unsigned table_size = {table_size}; static const unsigned io_type = nnet::{iotype}; static const unsigned reuse_factor = {reuse}; - typedef {act_t.name} table_t; + typedef {table_t.name} table_t; }};\n""" recr_activ_config_template = """struct {type}_config{index}_recr : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {table_t.name} table_t; +}};\n""" + +activ_config_template_brevitas = """struct {type}_config{index} : nnet::activ_config {{ + static const unsigned n_in = {n_in}; + static const unsigned table_size = {table_size}; + static const unsigned io_type = nnet::{iotype}; + static const unsigned reuse_factor = {reuse}; + typedef {act_t.name} table_t; +}};\n""" + +recr_activ_config_template_brevitas = """struct {type}_config{index}_recr : nnet::activ_config {{ static const unsigned n_in = {n_in}; static const unsigned table_size = {table_size}; static const unsigned io_type = nnet::{iotype}; @@ -99,6 +115,8 @@ def __init__(self): self.template = recr_config_template self.act_template = activ_config_template self.recr_act_template = recr_activ_config_template + self.act_template_brevitas = activ_config_template_brevitas + self.recr_act_template_brevitas = recr_activ_config_template_brevitas self.mult1_template = recr_mult_config_template_1 self.mult2_template = recr_mult_config_template_2 @@ -144,8 +162,12 @@ def format(self, node): act_params['n_in'] = node.get_output_variable().shape[0] recr_act_params['n_in'] = node.get_output_variable().shape[0] * (n_recr_mult - 1) - act_config = self.act_template.format(**act_params) - recr_act_config = self.recr_act_template.format(**recr_act_params) + if 'act_t' in act_params.keys(): + act_config = self.act_template_brevitas.format(**act_params) + recr_act_config = self.recr_act_template_brevitas.format(**recr_act_params) + else: + act_config = self.act_template.format(**act_params) + recr_act_config = self.recr_act_template.format(**recr_act_params) mult_params1 = self._default_config_params(node) mult_params2 = self._default_config_params(node) From 82c6c0f1cd7391228c493810465aec3f286badbc Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 2 Apr 2025 09:14:22 -0400 Subject: [PATCH 44/47] using custom tracer file from extension API --- hls4ml/converters/pytorch/tracer.py | 16 ------------ hls4ml/converters/pytorch_to_hls.py | 2 +- .../optimizer/passes/brevitas_optimizer.py | 2 +- 
hls4ml/model/optimizer/passes/quant_opt.py | 4 +-- hls4ml/utils/torch.py | 26 +++++++++++++++++++ 5 files changed, 30 insertions(+), 20 deletions(-) delete mode 100644 hls4ml/converters/pytorch/tracer.py create mode 100644 hls4ml/utils/torch.py diff --git a/hls4ml/converters/pytorch/tracer.py b/hls4ml/converters/pytorch/tracer.py deleted file mode 100644 index babddad0c1..0000000000 --- a/hls4ml/converters/pytorch/tracer.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - - -class CustomFXTracer(torch.fx.Tracer): - - def is_leaf_module(self, m, module_qualified_name: str) -> bool: - """ - Custom Tracher class for hls4ml to define brevitas modules as leaf modules so they are not traced through by torch.FX - """ - import torch - - return ( - m.__module__.startswith("torch.nn") - or m.__module__.startswith("torch.ao.nn") - or m.__module__.startswith("brevitas.nn") - ) and not isinstance(m, torch.nn.Sequential) diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 99edfd9f2e..8eef503e74 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -175,7 +175,7 @@ def parse_pytorch_model(config, verbose=True): """ import torch - from hls4ml.converters.pytorch.tracer import CustomFXTracer + from hls4ml.utils.torch import CustomFXTracer # This is a list of dictionaries to hold all the layer info we need to generate HLS layer_list = [] diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py index 294e00a20d..a786b9c60f 100644 --- a/hls4ml/model/optimizer/passes/brevitas_optimizer.py +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -19,7 +19,7 @@ def transform(self, model, node): # See if Quant layer needs to be added for the output if 'output_quantization' in node.attributes.keys() and not len(node.attributes['output_quantization']) == 0: - + print(node.attributes['output_quantization']) attributes = {} input = node.name diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 538d77f4f2..bfb3a02314 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -126,11 +126,11 @@ def transform(self, model, node): if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): _, exp = np.frexp(next(iter(scale.flat))) integer = bitwidth + exp - 1 - + print(integer, bitwidth) precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) attributes = {'activation': 'linear', 'quantizer': quantizer} - + print(str(precision)) # update the configuration config = model.config.get_layer_config(node) prec_config = config.setdefault('Precision', {}) diff --git a/hls4ml/utils/torch.py b/hls4ml/utils/torch.py new file mode 100644 index 0000000000..bfd2c9f0ca --- /dev/null +++ b/hls4ml/utils/torch.py @@ -0,0 +1,26 @@ +import torch + + +class HLS4MLModule(torch.nn.Module): + """ + Custom PyTorch module class for hls4ml to define custom modules that shouldn't be traced through by torch.FX + """ + + pass + + +class CustomFXTracer(torch.fx.Tracer): + + def is_leaf_module(self, m, module_qualified_name: str) -> bool: + """ + Custom Tracer class for hls4ml to define Brevitas modules and custom modules as leaf modules so they are not traced + through by torch.FX + """ + import torch + + return ( + isinstance(m, HLS4MLModule) + or m.__module__.startswith('torch.nn') + or m.__module__.startswith('torch.ao.nn') + or 
m.__module__.startswith('brevitas.nn') + ) and not isinstance(m, torch.nn.Sequential) From b4dd15e442ddafc1af15c8ff5f3d4536881c2bba Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 9 Apr 2025 13:22:57 -0400 Subject: [PATCH 45/47] remove QuantLSTM for now --- .../vivado/passes/recurrent_templates.py | 25 +- hls4ml/converters/pytorch/recurrent.py | 259 ++++-------------- hls4ml/converters/pytorch_to_hls.py | 2 +- hls4ml/model/layers.py | 27 +- hls4ml/model/optimizer/passes/quant_opt.py | 3 +- .../firmware/nnet_utils/nnet_recurrent.h | 2 +- hls4ml/templates/vivado/build_prj.tcl | 1 - test/pytest/test_recurrent_brevitas.py | 121 -------- 8 files changed, 59 insertions(+), 381 deletions(-) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index e25f50651c..4ef742bfed 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -58,21 +58,6 @@ typedef {table_t.name} table_t; }};\n""" -activ_config_template_brevitas = """struct {type}_config{index} : nnet::activ_config {{ - static const unsigned n_in = {n_in}; - static const unsigned table_size = {table_size}; - static const unsigned io_type = nnet::{iotype}; - static const unsigned reuse_factor = {reuse}; - typedef {act_t.name} table_t; -}};\n""" - -recr_activ_config_template_brevitas = """struct {type}_config{index}_recr : nnet::activ_config {{ - static const unsigned n_in = {n_in}; - static const unsigned table_size = {table_size}; - static const unsigned io_type = nnet::{iotype}; - static const unsigned reuse_factor = {reuse}; - typedef {recurr_act_t.name} table_t; -}};\n""" # LSTM + GRU templates @@ -115,8 +100,6 @@ def __init__(self): self.template = recr_config_template self.act_template = activ_config_template self.recr_act_template = recr_activ_config_template - self.act_template_brevitas = activ_config_template_brevitas - self.recr_act_template_brevitas = recr_activ_config_template_brevitas self.mult1_template = recr_mult_config_template_1 self.mult2_template = recr_mult_config_template_2 @@ -162,12 +145,8 @@ def format(self, node): act_params['n_in'] = node.get_output_variable().shape[0] recr_act_params['n_in'] = node.get_output_variable().shape[0] * (n_recr_mult - 1) - if 'act_t' in act_params.keys(): - act_config = self.act_template_brevitas.format(**act_params) - recr_act_config = self.recr_act_template_brevitas.format(**recr_act_params) - else: - act_config = self.act_template.format(**act_params) - recr_act_config = self.recr_act_template.format(**recr_act_params) + act_config = self.act_template.format(**act_params) + recr_act_config = self.recr_act_template.format(**recr_act_params) mult_params1 = self._default_config_params(node) mult_params2 = self._default_config_params(node) diff --git a/hls4ml/converters/pytorch/recurrent.py b/hls4ml/converters/pytorch/recurrent.py index b1d810dd57..a4bac46d0c 100644 --- a/hls4ml/converters/pytorch/recurrent.py +++ b/hls4ml/converters/pytorch/recurrent.py @@ -76,7 +76,7 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas return layer, output_shape -quant_rnn_layers = ['QuantRNN', 'QuantLSTM'] # No QuantGRU in brevitas at this point +quant_rnn_layers = ['QuantRNN'] # QuantLSTM very complex, might come later. 
No QuantGRU in brevitas at this point @pytorch_handler(*quant_rnn_layers) @@ -124,39 +124,20 @@ def parse_quant_rnn_layer(operation, layer_name, input_names, input_shapes, node layer['n_out'] = class_object._modules['layers'][0][0].hidden_size - if 'LSTM' in operation: - LSTMObject = class_object._modules['layers'][0][0] - - input_weight = LSTMObject.input_gate_params.input_weight - forget_weight = LSTMObject.forget_gate_params.input_weight - cell_weight = LSTMObject.cell_gate_params.input_weight - output_weight = LSTMObject.output_gate_params.input_weight - - input_hidden_weight = LSTMObject.input_gate_params.hidden_weight - forget_hidden_weight = LSTMObject.forget_gate_params.hidden_weight - cell_hidden_weight = LSTMObject.cell_gate_params.hidden_weight - output_hidden_weight = LSTMObject.output_gate_params.hidden_weight - - width = int(input_weight.quant_weight().bit_width) - scale = input_weight.quant_weight().scale.detach().numpy() + RNNObject = class_object._modules['layers'][0][0] + if RNNObject.gate_params.input_weight.weight_quant.is_quant_enabled: + width = int(RNNObject.gate_params.input_weight.quant_weight().bit_width) + scale = RNNObject.gate_params.input_weight.quant_weight().scale.detach().numpy() + signed = RNNObject.gate_params.input_weight.quant_weight().signed mantissa, _ = np.frexp(scale) # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly # use the already quantized tensor from brevitas if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(width, scale) - combined_weight = np.concatenate( - ( - input_weight.quant_weight().detach().value.numpy(), - forget_weight.quant_weight().detach().value.numpy(), - cell_weight.quant_weight().detach().value.numpy(), - output_weight.quant_weight().detach().value.numpy(), - ), - axis=0, - ) - layer['weight_data'] = combined_weight + ap_fixed_params = convert_uaq_to_apfixed(width, float(RNNObject.gate_params.input_weight.quant_weight().scale)) + layer['weight_data'] = RNNObject.gate_params.input_weight.quant_weight().detach().value.numpy() layer['weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) ) else: raise Exception( @@ -164,216 +145,70 @@ def parse_quant_rnn_layer(operation, layer_name, input_names, input_shapes, node Please used QONNX instead.''' ) - width = int(input_hidden_weight.quant_weight().bit_width) - scale = input_hidden_weight.quant_weight().scale.detach().numpy() + if RNNObject.gate_params.hidden_weight.weight_quant.is_quant_enabled: + width = int(RNNObject.gate_params.hidden_weight.quant_weight().bit_width) + scale = RNNObject.gate_params.hidden_weight.quant_weight().scale.detach().numpy() + signed = RNNObject.gate_params.input_weight.quant_weight().signed mantissa, _ = np.frexp(scale) # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly # use the already quantized tensor from brevitas if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(width, scale) - - combined_hidden_weight = np.concatenate( - ( - input_hidden_weight.quant_weight().detach().value.numpy(), - forget_hidden_weight.quant_weight().detach().value.numpy(), - cell_hidden_weight.quant_weight().detach().value.numpy(), - output_hidden_weight.quant_weight().detach().value.numpy(), - ), - axis=0, - ) - - layer['recurrent_weight_data'] = combined_hidden_weight + ap_fixed_params = convert_uaq_to_apfixed(width, 
float(RNNObject.gate_params.hidden_weight.quant_weight().scale)) + layer['recurrent_weight_data'] = RNNObject.gate_params.hidden_weight.quant_weight().detach().value.numpy() layer['recurrent_weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. - Please used QONNX instead.''' + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) ) - - input_bias = LSTMObject.input_gate_params.quant_bias() - forget_bias = LSTMObject.forget_gate_params.quant_bias() - cell_bias = LSTMObject.cell_gate_params.quant_bias() - output_bias = LSTMObject.output_gate_params.quant_bias() - - if input_bias is not None: - width = int(input_bias.bit_width) - scale = input_bias.scale.detach().numpy() - mantissa, _ = np.frexp(scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(width, scale) - - combined_hidden_weight = np.concatenate( - ( - input_bias.detach().value.numpy(), - forget_bias.detach().value.numpy(), - cell_bias.detach().value.numpy(), - output_bias.detach().value.numpy(), - ), - axis=0, - ) - - layer['bias_data'] = combined_hidden_weight - layer['bias_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. - Please used QONNX instead.''' - ) - else: - layer['bias_data'] = np.zeros(layer['weight_data'].shape[0]) - layer['bias_quantizer'] = layer['weight_quantizer'] - - layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0]) - layer['recurrent_bias_quantizer'] = layer['bias_quantizer'] - - acc_scale = LSTMObject.cell.forget_acc_quant.scale() - acc_bitwdith = int(LSTMObject.cell.forget_acc_quant.bit_width()) - mantissa, _ = np.frexp(acc_scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(acc_bitwdith, acc_scale) - precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision) - else: raise Exception( '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
Please used QONNX instead.''' ) - tanh_scale = LSTMObject.cell.cell_tanh_quant.scale() - tanh_bitwdith = int(LSTMObject.cell.cell_tanh_quant.bit_width()) - mantissa, _ = np.frexp(tanh_scale) + input_bias = RNNObject.gate_params.quant_bias() + if input_bias is not None: + width = int(input_bias.bit_width) + scale = input_bias.scale.detach().numpy() + mantissa, _ = np.frexp(scale) # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly # use the already quantized tensor from brevitas if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(tanh_bitwdith, tanh_scale) - precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - layer['act_t'] = NamedType(layer["name"] + '_act_t', precision) + ap_fixed_params = convert_uaq_to_apfixed(width, scale) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. - Please used QONNX instead.''' + layer['bias_data'] = input_bias.detach().value.numpy() + layer['bias_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) ) - - sigmoid_scale = LSTMObject.cell.cell_tanh_quant.scale() - sigmoid_bitwdith = int(LSTMObject.cell.cell_tanh_quant.bit_width()) - mantissa, _ = np.frexp(sigmoid_scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(sigmoid_bitwdith, sigmoid_scale) - precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - layer['recurr_act_t'] = NamedType(layer["name"] + '_recurr_act_t', precision) - else: raise Exception( '''Non-power of 2 quantization of weights not supported when injecting brevitas models. Please used QONNX instead.''' ) - if LSTMObject.cell.output_quant.is_quant_enabled: - layer = addQuantizationParameters(layer, LSTMObject.cell.output_quant, 'output', act=True) - layer = addQuantizationParameters(layer, LSTMObject.cell.output_quant, 'input', act=True) - else: + layer['bias_data'] = np.zeros(layer['weight_data'].shape[0]) + layer['bias_quantizer'] = layer['weight_quantizer'] - RNNObject = class_object._modules['layers'][0][0] - - if RNNObject.gate_params.input_weight.weight_quant.is_quant_enabled: - width = int(RNNObject.gate_params.input_weight.quant_weight().bit_width) - scale = RNNObject.gate_params.input_weight.quant_weight().scale.detach().numpy() - signed = RNNObject.gate_params.input_weight.quant_weight().signed - mantissa, _ = np.frexp(scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed( - width, float(RNNObject.gate_params.input_weight.quant_weight().scale) - ) - layer['weight_data'] = RNNObject.gate_params.input_weight.quant_weight().detach().value.numpy() - layer['weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) - ) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
- Please used QONNX instead.''' - ) - - if RNNObject.gate_params.hidden_weight.weight_quant.is_quant_enabled: - width = int(RNNObject.gate_params.hidden_weight.quant_weight().bit_width) - scale = RNNObject.gate_params.hidden_weight.quant_weight().scale.detach().numpy() - signed = RNNObject.gate_params.input_weight.quant_weight().signed - mantissa, _ = np.frexp(scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed( - width, float(RNNObject.gate_params.hidden_weight.quant_weight().scale) - ) - layer['recurrent_weight_data'] = RNNObject.gate_params.hidden_weight.quant_weight().detach().value.numpy() - layer['recurrent_weight_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed) - ) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. - Please used QONNX instead.''' - ) - - input_bias = RNNObject.gate_params.quant_bias() - if input_bias is not None: - width = int(input_bias.bit_width) - scale = input_bias.scale.detach().numpy() - mantissa, _ = np.frexp(scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(width, scale) - - layer['bias_data'] = input_bias.detach().value.numpy() - layer['bias_quantizer'] = BrevitasQuantizer( - width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - ) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. - Please used QONNX instead.''' - ) - else: - layer['bias_data'] = np.zeros(layer['weight_data'].shape[0]) - layer['bias_quantizer'] = layer['weight_quantizer'] - - layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0]) - layer['recurrent_bias_quantizer'] = layer['weight_quantizer'] + layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0]) + layer['recurrent_bias_quantizer'] = layer['weight_quantizer'] - acc_scale = RNNObject.cell.gate_acc_quant.scale() - acc_bitwdith = int(RNNObject.cell.gate_acc_quant.bit_width()) - mantissa, _ = np.frexp(acc_scale) - # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly - # use the already quantized tensor from brevitas - if mantissa == 0.5: - ap_fixed_params = convert_uaq_to_apfixed(acc_bitwdith, acc_scale) - precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) - layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision) + acc_scale = RNNObject.cell.gate_acc_quant.scale() + acc_bitwdith = int(RNNObject.cell.gate_acc_quant.bit_width()) + mantissa, _ = np.frexp(acc_scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(acc_bitwdith, acc_scale) + precision = FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision) - else: - raise Exception( - '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
- Please used QONNX instead.''' - ) - - if RNNObject.cell.output_quant.is_quant_enabled: - layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'output', act=True) - layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'input', act=True) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. + Please used QONNX instead.''' + ) + + if RNNObject.cell.output_quant.is_quant_enabled: + layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'output', act=True) + layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'input', act=True) if layer['class_name'] == 'GRU': layer['apply_reset_gate'] = 'after' # Might be true for pytorch? It's not a free parameter diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py index 8eef503e74..de30b879ae 100644 --- a/hls4ml/converters/pytorch_to_hls.py +++ b/hls4ml/converters/pytorch_to_hls.py @@ -283,7 +283,7 @@ def parse_pytorch_model(config, verbose=True): # parse info from class object input_names = [inputs_map.get(str(i), str(i)) for i in node.args] - if pytorch_class in ['RNN', 'GRU', 'LSTM', 'QuantRNN', 'QuantLSTM']: + if pytorch_class in ['RNN', 'GRU', 'LSTM', 'QuantRNN']: input_shapes = [] input_names = [] for arg in node.args: diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index ab2cc5bdfc..50230b0a7e 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1293,7 +1293,6 @@ class SimpleRNN(Layer): TypeAttribute('weight'), TypeAttribute('bias'), TypeAttribute('recurrent_weight'), - TypeAttribute('accum_t'), ] def initialize(self): @@ -1317,19 +1316,15 @@ def initialize(self): ) # weights - self.add_weights(quantizer=self.get_attr('weight_quantizer')) + self.add_weights() # recurrent weights - self.add_weights_variable( - name='recurrent_weight', var_name='wr{index}', quantizer=self.get_attr('recurrent_weight_quantizer') - ) + self.add_weights_variable(name='recurrent_weight', var_name='wr{index}') # biases - self.add_weights_variable(name='bias', var_name='b{index}', quantizer=self.get_attr('bias_quantizer')) + self.add_weights_variable(name='bias', var_name='b{index}') if "pytorch" in self.attributes.keys(): - self.add_weights_variable( - name='recurrent_bias', var_name='br{index}', quantizer=self.get_attr('recurrent_bias_quantizer') - ) + self.add_weights_variable(name='recurrent_bias', var_name='br{index}') class LSTM(Layer): @@ -1350,7 +1345,6 @@ class LSTM(Layer): TypeAttribute('bias'), TypeAttribute('recurrent_weight'), TypeAttribute('recurrent_bias'), - TypeAttribute('accum_t'), ] def initialize(self): @@ -1374,24 +1368,17 @@ def initialize(self): ) # weights - self.add_weights(quantizer=self.get_attr('weight_quantizer')) + self.add_weights() # recurrent weights recurrent_weight = self.get_attr('recurrent_weight_data') - self.add_weights_variable( - name='recurrent_weight', - var_name='wr{index}', - data=recurrent_weight, - quantizer=self.get_attr('recurrent_weight_quantizer'), - ) + self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight) # biases self.add_weights_variable(name='bias', var_name='b{index}', quantizer=self.get_attr('bias_quantizer')) if "pytorch" in self.attributes.keys(): - self.add_weights_variable( - name='recurrent_bias', var_name='br{index}', quantizer=self.get_attr('recurrent_bias_quantizer') - ) + self.add_weights_variable(name='recurrent_bias', var_name='br{index}') else: recurrent_bias = 
np.zeros(recurrent_weight.shape[1]) self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias) diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index bfb3a02314..08eb0049a0 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -126,11 +126,10 @@ def transform(self, model, node): if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): _, exp = np.frexp(next(iter(scale.flat))) integer = bitwidth + exp - 1 - print(integer, bitwidth) precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) attributes = {'activation': 'linear', 'quantizer': quantizer} - print(str(precision)) + # update the configuration config = model.config.get_layer_config(node) prec_config = config.setdefault('Precision', {}) diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h index 53b4d79618..0d9a263a48 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h @@ -14,7 +14,7 @@ namespace nnet { template void multiply_W(data_T input[N_IN], res_T out[N_OUT], const weight_t weight[N_IN * N_OUT]) { MULTIPLY_W_LOOP_I: - #pragma unroll, + #pragma unroll for (int i = 0; i < N_OUT; i++) { out[i] = 0; diff --git a/hls4ml/templates/vivado/build_prj.tcl b/hls4ml/templates/vivado/build_prj.tcl index 18a26c2c71..888c5f4c95 100644 --- a/hls4ml/templates/vivado/build_prj.tcl +++ b/hls4ml/templates/vivado/build_prj.tcl @@ -233,7 +233,6 @@ if {$opt(validation)} { if {$opt(export)} { puts "***** EXPORT IP *****" set time_start [clock clicks -milliseconds] - export_design -format ip_catalog -version $version set time_end [clock clicks -milliseconds] report_time "EXPORT IP" $time_start $time_end diff --git a/test/pytest/test_recurrent_brevitas.py b/test/pytest/test_recurrent_brevitas.py index 78de67cc49..fa6dd1c4e7 100644 --- a/test/pytest/test_recurrent_brevitas.py +++ b/test/pytest/test_recurrent_brevitas.py @@ -68,124 +68,3 @@ def test_rnn(backend, io_type): hls_prediction = np.reshape(hls_model.predict([X_input.detach().numpy(), h0.detach().numpy()]), pytorch_prediction.shape) np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far - - -class QuantLSTMModel(nn.Module): - def __init__(self): - super().__init__() - self.rnn = qnn.QuantLSTM( - 10, - 20, - num_layers=1, - batch_first=True, - bias=False, - weight_quant=Int8WeightPerTensorFixedPoint, - bias_quant=Int8BiasPerTensorFixedPointInternalScaling, - gate_acc_quant=Int8ActPerTensorFixedPoint, - return_quant_tensor=True, - sigmoid_quant=Int8ActPerTensorFixedPoint, - tanh_quant=Int8ActPerTensorFixedPoint, - shared_intra_layer_weight_quant=True, - shared_intra_layer_gate_acc_quant=True, - ) - - def forward(self, x, h0, c0): - output, (_, _) = self.rnn(x, (h0, c0)) - return output - - -class QuantLSTMModelStream(nn.Module): - def __init__(self): - super().__init__() - self.rnn = qnn.QuantLSTM( - 10, - 20, - num_layers=1, - batch_first=True, - bias=False, - weight_quant=Int8WeightPerTensorFixedPoint, - bias_quant=Int8BiasPerTensorFixedPointInternalScaling, - gate_acc_quant=Int8ActPerTensorFixedPoint, - return_quant_tensor=True, - sigmoid_quant=Int8ActPerTensorFixedPoint, - tanh_quant=Int8ActPerTensorFixedPoint, - shared_intra_layer_weight_quant=True, - 
shared_intra_layer_gate_acc_quant=True, - ) - - def forward(self, x): - output, (_, _) = self.rnn(x) - return output - - -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) -@pytest.mark.parametrize('io_type', ['io_parallel']) -def test_lstm(backend, io_type): - model = QuantLSTMModel() - model.eval() - - X_input = torch.randn(1, 1, 10) - X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> - h0 = torch.randn(1, 1, 20) - h0 = np.round(h0 * 2**16) * 2**-16 - c0 = torch.randn(1, 1, 20) - c0 = np.round(c0 * 2**16) * 2**-16 - - pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0), torch.tensor(c0)).detach().value.numpy() - - config = config_from_pytorch_model( - model, - [(None, 1, 10), (None, 1, 20), (None, 1, 20)], - channels_last_conversion="off", - transpose_outputs=False, - default_precision='fixed<32,16>', - ) - output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}') - - hls_model = convert_from_pytorch_model( - model, - hls_config=config, - output_dir=output_dir, - backend=backend, - io_type=io_type, - ) - - hls_model.compile() - - hls_prediction = np.reshape( - hls_model.predict([X_input.detach().numpy(), h0.detach().numpy(), c0.detach().numpy()]), pytorch_prediction.shape - ) - - np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far - - -@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) -@pytest.mark.parametrize('io_type', ['io_stream']) -def test_lstm_stream(backend, io_type): - if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"): - model = QuantLSTMModelStream() - model.eval() - - X_input = torch.randn(1, 1, 10) - X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> - - pytorch_prediction = model(torch.Tensor(X_input)).detach().value.numpy() - - config = config_from_pytorch_model( - model, [(None, 1, 10)], channels_last_conversion="off", transpose_outputs=False, default_precision='fixed<32,16>' - ) - output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}') - - hls_model = convert_from_pytorch_model( - model, - hls_config=config, - output_dir=output_dir, - backend=backend, - io_type=io_type, - ) - - hls_model.compile() - - hls_prediction = np.reshape(hls_model.predict(X_input.detach().numpy()), pytorch_prediction.shape) - - np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2) # quite bad accuracy so far From ed1c2c85cc2f9db20b25f0df82291e2fc7a3524f Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Wed, 9 Apr 2025 13:24:56 -0400 Subject: [PATCH 46/47] clean diff --- hls4ml/backends/vivado/passes/recurrent_templates.py | 1 - hls4ml/model/layers.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index 9afb4f2836..6934e82e4e 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -58,7 +58,6 @@ typedef {table_t.name} table_t; }};\n""" - # LSTM + GRU templates recr_config_template = """struct config{index} : nnet::{recr_type}_config {{ diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f28bdb8537..3d773838b1 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1375,7 +1375,7 @@ def initialize(self): self.add_weights_variable(name='recurrent_weight', var_name='wr{index}', data=recurrent_weight) # biases - 
self.add_weights_variable(name='bias', var_name='b{index}', quantizer=self.get_attr('bias_quantizer')) + self.add_weights_variable(name='bias', var_name='b{index}') if "pytorch" in self.attributes.keys(): self.add_weights_variable(name='recurrent_bias', var_name='br{index}') From 6f1c24f6fb044d462a16f215e51b1c64d0d91428 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Mon, 14 Apr 2025 13:07:49 -0400 Subject: [PATCH 47/47] fix failing pytests --- .../firmware/nnet_utils/nnet_recurrent.h | 2 +- .../firmware/nnet_utils/nnet_recurrent.h | 174 +----------------- test/pytest/test_brevitas_parsing.py | 4 +- 3 files changed, 5 insertions(+), 175 deletions(-) diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h index 678161006f..3367213167 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h @@ -493,7 +493,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T &hin, res_T &re } // Do SimpleRNN - simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); + simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); // Write result #pragma unroll diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h index 0d9a263a48..794e46972e 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h @@ -58,14 +58,6 @@ template void multiply_vectors(data_T in1[N], } } -template void multiply_vectors(data_T in1[N], s_T in2[N], res_T out[N]) { -MULTIPLY_VECT_LOOP: - #pragma unroll - for (int i = 0; i < N; i++) { - out[i] = in1[i] * in2[i]; - } -} - template void add_vectors(data_T in1[N], data_T in2[N], res_T out[N]) { ADD_VECTOR_LOOP: #pragma unroll @@ -726,168 +718,6 @@ void lstm_cell(data_T inputs[CONFIG_T::n_in], res_T hidden_state[CONFIG_T::n_out } } -template -void lstm_cell(data_T inputs[CONFIG_T::n_in], h_T hidden_state[CONFIG_T::n_out], h_T hidden_state_o[CONFIG_T::n_out], - s_T cell_state[CONFIG_T::n_out], s_T cell_state_o[CONFIG_T::n_out], - const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out], - const typename CONFIG_T::weight_t WF[CONFIG_T::n_in * CONFIG_T::n_out], - const typename CONFIG_T::weight_t WC[CONFIG_T::n_in * CONFIG_T::n_out], - const typename CONFIG_T::weight_t WO[CONFIG_T::n_in * CONFIG_T::n_out], - const typename CONFIG_T::weight_t RWI[CONFIG_T::n_out * CONFIG_T::n_out], - const typename CONFIG_T::weight_t RWF[CONFIG_T::n_out * CONFIG_T::n_out], - const typename CONFIG_T::weight_t RWC[CONFIG_T::n_out * CONFIG_T::n_out], - const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out], - const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out], - const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) { - - // Internals definitions - typename CONFIG_T::accum_t i_afterW[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t i_afterBias[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t c_afterW[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t c_afterBias[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t o_afterW[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t o_afterBias[CONFIG_T::n_out] hls_register; - 
typename CONFIG_T::accum_t f_afterW[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t f_afterBias[CONFIG_T::n_out] hls_register; - - // Hidden state Gate candidates, intermediate variables - typename CONFIG_T::accum_t i_hiddenCand[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t f_hiddenCand[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t c_hiddenCand[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t o_hiddenCand[CONFIG_T::n_out] hls_register; - - // After addition, intermediate variables - typename CONFIG_T::accum_t i_afterAdd[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t f_afterAdd[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t c_afterAdd[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t o_afterAdd[CONFIG_T::n_out] hls_register; - - // Gate outputs - typename CONFIG_T::accum_t gate_i[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t gate_f[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t gate_c[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t gate_o[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t gate_ic[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t gate_forget[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t h[CONFIG_T::n_out] hls_register; - - // Intermediate variable cell calculation - typename CONFIG_T::accum_t cell_act_multp[CONFIG_T::n_out] hls_register; - typename CONFIG_T::accum_t cell_act_add[CONFIG_T::n_out] hls_register; - - //-----------Gate I Calculations - // Weight multiplication - multiply_W( - inputs, i_afterW, WI); - - // Bias addition - add_bias( - i_afterW, i_afterBias, BI); - - // Hidden Candidate - multiply_U(hidden_state, i_hiddenCand, - RWI); - - // Vector addition - add_vectors(i_afterBias, i_hiddenCand, - i_afterAdd); - - // Activation - CONFIG_T::template activation_recr::activation(i_afterAdd, gate_i); - - //-----------Gate F Calculations - // Weight multiplication - multiply_W( - inputs, f_afterW, WF); - - // Bias addition - add_bias( - f_afterW, f_afterBias, BF); - - // Hidden Candidate - multiply_U(hidden_state, f_hiddenCand, - RWF); - - // Vector addition - add_vectors(f_afterBias, f_hiddenCand, - f_afterAdd); - - // Activation - CONFIG_T::template activation_recr::activation(f_afterAdd, gate_f); - - //-----------Gate C Calculations - // Weight multiplication - multiply_W( - inputs, c_afterW, WC); - - // Bias addition - add_bias( - c_afterW, c_afterBias, BC); - - // Hidden Candidate - multiply_U(hidden_state, c_hiddenCand, - RWC); - - // Vector addition - add_vectors(c_afterBias, c_hiddenCand, - c_afterAdd); - - // Activation - CONFIG_T::template activation::activation(c_afterAdd, gate_c); - - //-----------gate I and C multiply - // Vector multiplication - multiply_vectors(gate_i, gate_c, gate_ic); - - //-----------Gate O Calculations - // Weight multiplication - multiply_W( - inputs, o_afterW, WO); - - // Bias addition - add_bias( - o_afterW, o_afterBias, BO); - - // Hidden Candidate - multiply_U(hidden_state, o_hiddenCand, - RWO); - - // Vector addition - add_vectors(o_afterBias, o_hiddenCand, - o_afterAdd); - - // Activation - CONFIG_T::template activation_recr::activation(o_afterAdd, gate_o); - - //-----------Cell State Calculation - // Vector multiplication - multiply_vectors(gate_f, cell_state, - cell_act_multp); - - // Vector addition - add_vectors(gate_ic, cell_act_multp, - cell_act_add); - - //-----------Forget gate Calculation - // Activation - CONFIG_T::template activation::activation(cell_act_add, 
gate_forget); - - // Vector multiplication - multiply_vectors(gate_o, gate_forget, h); - -OUTPUT_WRITE_LOOP: - #pragma unroll - for (int x = (CONFIG_T::n_out - 1); x >= 0; x--) { - hidden_state_o[x] = h[x]; - cell_state_o[x] = cell_act_add[x]; - } -} - template void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out], const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out], @@ -1005,8 +835,8 @@ void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], h_T hidden_state_ } // Do LSTM - lstm_cell(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, - RWC, RWO, BI, BF, BC, BO); + lstm_cell(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO, + BI, BF, BC, BO); // Write result #pragma unroll diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py index 293620e01c..c8989aaed2 100644 --- a/test/pytest/test_brevitas_parsing.py +++ b/test/pytest/test_brevitas_parsing.py @@ -104,7 +104,7 @@ def test_quantconv1d(backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') - from hls4ml.converters.pytorch.tracer import CustomFXTracer + from hls4ml.utils.torch import CustomFXTracer tracer = CustomFXTracer() traced_model = tracer.trace(model) @@ -165,7 +165,7 @@ def test_quantconv2d(backend, io_type): output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}') - from hls4ml.converters.pytorch.tracer import CustomFXTracer + from hls4ml.utils.torch import CustomFXTracer tracer = CustomFXTracer() traced_model = tracer.trace(model)
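
A note on the power-of-two scale handling that recurs in parse_quant_rnn_layer and in quant_opt.py above: a Brevitas scale is only absorbed into an hls4ml FixedPrecisionType when np.frexp(scale) returns a mantissa of exactly 0.5, i.e. when the scale is a power of two, in which case the integer width works out to bitwidth + exp - 1. The standalone sketch below illustrates just that mapping; fixed_point_from_scale is an invented name for illustration, not an hls4ml API (the patch uses convert_uaq_to_apfixed for this).

import numpy as np

def fixed_point_from_scale(bitwidth, scale):
    # Illustrative stand-in: returns the (width, integer) bits of an ap_fixed type
    # that exactly represents a signed quantizer with the given power-of-two scale.
    mantissa, exp = np.frexp(scale)
    if mantissa != 0.5:
        # Non-power-of-2 scales cannot be folded into ap_fixed; the parser raises
        # and points users at the QONNX flow in this case.
        raise ValueError('scale is not a power of two')
    integer = int(bitwidth + exp - 1)  # scale == 2**(exp - 1)
    return bitwidth, integer

print(fixed_point_from_scale(8, 2**-6))  # -> (8, 2), i.e. ap_fixed<8,2> with steps of 2**-6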
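
The pytest inputs are snapped onto a 2**-16 grid ("make it exact ap_fixed<32,16>") so that the PyTorch reference and the fixed<32,16> HLS model see bit-identical inputs. The same trick in isolation looks like this (illustrative snippet, not part of the patch):

import numpy as np

x = np.random.randn(1, 1, 10)
x_fixed = np.round(x * 2**16) * 2**-16  # snap onto the 2**-16 grid of ap_fixed<32,16>
# every snapped value survives a round trip through the grid unchanged
assert np.array_equal(x_fixed, np.round(x_fixed * 2**16) * 2**-16)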
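
Finally, the tracer import in test_brevitas_parsing.py now comes from hls4ml.utils.torch. A minimal usage sketch, assuming a hypothetical toy module (TinyQuantModel is not part of the patch), showing that brevitas.nn layers stay as single call_module leaf nodes instead of being traced through by torch.FX:

import torch.nn as nn
import brevitas.nn as qnn
from hls4ml.utils.torch import CustomFXTracer


class TinyQuantModel(nn.Module):
    # hypothetical toy model, used only to exercise the tracer
    def __init__(self):
        super().__init__()
        self.fc = qnn.QuantLinear(10, 4, bias=True)

    def forward(self, x):
        return self.fc(x)


tracer = CustomFXTracer()
graph = tracer.trace(TinyQuantModel())
# QuantLinear shows up as one call_module node; its internals are not traced
print([(node.op, str(node.target)) for node in graph.nodes])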