diff --git a/src/matlab2cpp/node/backend.py b/src/matlab2cpp/node/backend.py index b6e88ff..34419a5 100755 --- a/src/matlab2cpp/node/backend.py +++ b/src/matlab2cpp/node/backend.py @@ -183,9 +183,11 @@ def auxillary(node, type, convert): # Return value aux_var = matlab2cpp.collection.Var(assign, var) + aux_var.create_declare() + aux_var.declare.type = type + aux_var.declare.backend = type aux_var.type = type aux_var.backend = type - aux_var.create_declare() if convert: rhs = matlab2cpp.collection.Get(assign, "_conv_to") @@ -194,7 +196,10 @@ def auxillary(node, type, convert): rhs = assign swap_var = matlab2cpp.collection.Var(rhs, var) + swap_var.type = type + swap_var.backend = type swap_var.declare.type = type + swap_var.declare.backend = type # Place Assign correctly in Block i = block.children.index(line) @@ -353,7 +358,6 @@ def create_declare(node): var.type="struct" return matlab2cpp.collection.Var(struct, name=value) - parent = struct else: parent = node.func[0] diff --git a/src/matlab2cpp/node/frontend.py b/src/matlab2cpp/node/frontend.py index d6c4020..63e6252 100755 --- a/src/matlab2cpp/node/frontend.py +++ b/src/matlab2cpp/node/frontend.py @@ -301,7 +301,7 @@ def auxiliary(self, type=None, convert=False): 1 4| | | | Int int int 1 1| Statement code_block TYPE 1 1| | Plus expression irowvec - 1 1| | | Var unknown irowvec _aux_irowvec_1 + 1 1| | | Var irowvec irowvec _aux_irowvec_1 1 7| | | Int int int """ return backend.auxillary(self, type, convert) diff --git a/src/matlab2cpp/rules/_reserved.py b/src/matlab2cpp/rules/_reserved.py index 75a0236..63cdaee 100644 --- a/src/matlab2cpp/rules/_reserved.py +++ b/src/matlab2cpp/rules/_reserved.py @@ -996,14 +996,14 @@ def Get_sum(node): >>> print(matlab2cpp.qscript("a=[-1,2;3,4]; b = sum(a, 1)", suggest=True)) sword _a [] = {-1, 2, 3, 4} ; a = arma::strans(imat(_a, 2, 2, false)) ; - b = arma::sum(arma:vectorize(a), 0) ; + b = arma::sum(arma::vectorise(a), 0) ; >>> print(matlab2cpp.qscript("a=[1., 2.; 3., 4.]; b = sum(a(:))", suggest=True)) double _a [] = {1., 2., 3., 4.} ; a = arma::strans(mat(_a, 2, 2, false)) ; - b = double(arma::as_scalar(arma::sum(arma:vectorize(a(span(0, a.n_rows-1)))))) ; + b = double(arma::as_scalar(arma::sum(arma::vectorise(a)))) ; >>> print(matlab2cpp.qscript("a=rand(9, 9, 9); b = sum(a(:))", suggest=True)) a = arma::randu(9, 9, 9) ; - b = double(arma::as_scalar(arma::sum(arma:vectorize(a(span(0, a.n_rows-1)))))) ; + b = double(arma::as_scalar(arma::sum(arma::vectorise(a)))) ; """ arg = node[0] # unknown input @@ -1012,7 +1012,7 @@ def Get_sum(node): return "arma::sum(", ", ", ")" if arg.dim > 2: - arg = "arma:vectorize(%(0)s)" + arg = "arma::vectorise(%(0)s)" else: arg = "%(0)s" diff --git a/src/matlab2cpp/rules/armadillo.py b/src/matlab2cpp/rules/armadillo.py index 512cdd0..6909a26 100644 --- a/src/matlab2cpp/rules/armadillo.py +++ b/src/matlab2cpp/rules/armadillo.py @@ -20,7 +20,7 @@ def configure_arg(node, index): >>> print(matlab2cpp.qscript('x=[1,2]; x(:)')) sword _x [] = {1, 2} ; x = irowvec(_x, 2, false) ; - x(span(0, x.n_rows-1)) ; + x ; >>> print(matlab2cpp.qscript('x=[1,2]; x(1)')) sword _x [] = {1, 2} ; x = irowvec(_x, 2, false) ; @@ -42,7 +42,6 @@ def configure_arg(node, index): x = irowvec(_x, 2, false) ; x(arma::trans(x)-1) ; """ - out = "%(" + str(index) + ")s" # the full range ':' @@ -71,7 +70,7 @@ def configure_arg(node, index): # undefined type elif node.type == "TYPE": - return out, -1 + return out , -1 # float point scalar elif node.mem > 1 and node.dim == 0: @@ -103,10 +102,9 @@ def configure_arg(node, index): else: dim = 1 - - if len(node) > 0 and node[0].cls == "Paren": + if (len(node) > 0) and (node[0].cls == "Paren"): pass - elif node.cls not in ["Colon", "Paren"]: + elif node.cls not in ("Colon", "Paren"): out = out + "-1" return out, dim diff --git a/src/matlab2cpp/rules/mat.py b/src/matlab2cpp/rules/mat.py index d9d8886..b8bb268 100755 --- a/src/matlab2cpp/rules/mat.py +++ b/src/matlab2cpp/rules/mat.py @@ -116,9 +116,8 @@ def Set(node): a(n) = b """ - # wrong number of argumets - if len(node) not in (1,2): + if len(node) not in (1, 2): if not len(node): node.error("Zero arguments in a matrix set") @@ -195,17 +194,27 @@ def Set(node): # uvec + scalar elif dim0 > 0 and dim1 == 0: - index = node[0].str.index('(') - return "%(name)s(" + "m2cpp::span" + node[0].str[index:] + \ - ", m2cpp::span(" + arg1 + ", " + arg1 + "))" + + if "(" in node[0].str: + index = node[0].str.index('(') + lhs = node[0].str[index:] + return ( + "%(name)s(m2cpp::span" + lhs + + ", m2cpp::span(" + arg1 + ", " + arg1 + "))" + ) + return "%(name)s(" + arg0 + ", " + arg1 + ")" #return "%(name)s.row(" + arg0 + ").cols(" + arg1 + ")" # uvec + uvec if dim0 > 0 and dim1 > 0: - a0 = node[0].str.replace("arma::span", "m2cpp::span") - a1 = node[1].str.replace("arma::span", "m2cpp::span") - - return "%(name)s(" + a0 + ", " + a1 + ")" + arg0 = node[0].str.replace("arma::span", "m2cpp::span") + arg1 = node[1].str.replace("arma::span", "m2cpp::span") + if arg0.startswith("_aux_"): + arg0 = arg0 + "-1" + if arg1.startswith("_aux_"): + arg1 = arg1 + "-1" + + return "%(name)s(" + arg0 + ", " + arg1 + ")" return "%(name)s(" + arg0 + ", " + arg1 + ")" diff --git a/src/matlab2cpp/tree/variables.py b/src/matlab2cpp/tree/variables.py index 3229790..18c48ad 100755 --- a/src/matlab2cpp/tree/variables.py +++ b/src/matlab2cpp/tree/variables.py @@ -397,8 +397,14 @@ def variable(self, parent, cur): print("%-20s" % "variables.variable", end="") print(repr(self.code[cur:end+1])) - node = collection.Get(parent, name, cur=cur, - code=self.code[cur:end+1]) + # 'A(:)' is equivalent to 'A': + if self.code[end-2:end+1] == "(:)": + node = collection.Var( + parent, name, cur=cur, code=self.code[cur:end+1]) + + else: + node = collection.Get( + parent, name, cur=cur, code=self.code[cur:end+1]) last = self.create_list(node, k) cur = last diff --git a/test/test_conversion.py b/test/test_fx_decon.py similarity index 100% rename from test/test_conversion.py rename to test/test_fx_decon.py diff --git a/test/test_simple_assignment.py b/test/test_snippet_conversion.py similarity index 96% rename from test/test_simple_assignment.py rename to test/test_snippet_conversion.py index 8c9a821..f00887a 100644 --- a/test/test_simple_assignment.py +++ b/test/test_snippet_conversion.py @@ -1,7 +1,7 @@ """Test simple assignment.""" import pytest -from matlab2cpp import qcpp, qhpp +from matlab2cpp import qcpp, qhpp, qtree @pytest.fixture(params=[