diff --git a/isthmus/build.gradle.kts b/isthmus/build.gradle.kts index b77c1abb9..d7d784bff 100644 --- a/isthmus/build.gradle.kts +++ b/isthmus/build.gradle.kts @@ -7,9 +7,18 @@ plugins { id("com.diffplug.spotless") version "6.19.0" id("com.github.johnrengelman.shadow") version "8.1.1" id("com.google.protobuf") version "0.9.4" + id("com.adarshr.test-logger") version "4.0.0" signing } +// Useful test logger when debugging, will output stdout/stderr to console +// saves time launching the HTML test reports +testlogger { + showStandardStreams = false + showPassedStandardStreams = true + showFailedStandardStreams = true +} + publishing { publications { create("maven-publish") { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 45b42f814..ebc7b8bf0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -1,6 +1,6 @@ package io.substrait.isthmus; -import static io.substrait.isthmus.SqlToSubstrait.EXTENSION_COLLECTION; +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import com.google.common.collect.ImmutableList; import io.substrait.expression.Expression; @@ -44,6 +44,7 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSlot; @@ -51,6 +52,7 @@ import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Holder; /** * RelVisitor to convert Substrait Rel plan to Calcite RelNode plan. 
Unsupported Rel node will call @@ -139,8 +141,12 @@ public static RelNode convert( @Override public RelNode visit(Filter filter) throws RuntimeException { RelNode input = filter.getInput().accept(this); + final Holder v = Holder.empty(); + expressionRexConverter.addCorrelVariable(v); + + RelBuilder r1 = relBuilder.push(input).variable(v::set); RexNode filterCondition = filter.getCondition().accept(expressionRexConverter); - RelNode node = relBuilder.push(input).filter(filterCondition).build(); + RelNode node = r1.filter(List.of(v.get().id), filterCondition).build(); return applyRemap(node, filter.getRemap()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index f7dd76f6e..f70fd7cac 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -22,8 +22,10 @@ import io.substrait.type.Type; import io.substrait.util.DecimalUtil; import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -34,6 +36,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; @@ -46,6 +49,7 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.util.Holder; import org.apache.calcite.util.TimeString; import org.apache.calcite.util.TimestampString; @@ -468,7 +472,7 @@ public 
RexNode visit(Expression.InPredicate expr) throws RuntimeException { return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); } - static class ToRexWindowBound + static final class ToRexWindowBound implements WindowBound.WindowBoundVisitor { static RexWindowBound lowerBound(RexBuilder rexBuilder, WindowBound bound) { @@ -538,22 +542,39 @@ public RexNode visit(Expression.Cast expr) throws RuntimeException { @Override public RexNode visit(FieldReference expr) throws RuntimeException { if (expr.isSimpleRootReference()) { + Optional outerref = expr.outerReferenceStepsOut(); var segment = expr.segments().get(0); + if (outerref.isPresent()) { + if (segment instanceof FieldReference.StructField) { + FieldReference.StructField f = (FieldReference.StructField) segment; + var node = referenceRelList.get(outerref.get() - 1).get(); - RexInputRef rexInputRef; - if (segment instanceof FieldReference.StructField f) { - rexInputRef = - new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + return rexBuilder.makeFieldAccess(node, f.offset()); + } } else { - throw new IllegalArgumentException("Unhandled type: " + segment); + RexInputRef rexInputRef; + if (segment instanceof FieldReference.StructField f) { + rexInputRef = + new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType())); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } + return rexInputRef; } - - return rexInputRef; } - return visitFallback(expr); } + protected List> referenceRelList = new ArrayList<>(); + + public void addCorrelVariable(Holder correlVaraible) { + referenceRelList.add(correlVaraible); + } + + public Holder getOuterRef(int i) { + return referenceRelList.get(i); + } + @Override public RexNode visitFallback(Expression expr) { throw new UnsupportedOperationException( diff --git a/isthmus/src/test/java/io/substrait/isthmus/TestExtendedTpchQuery.java b/isthmus/src/test/java/io/substrait/isthmus/TestExtendedTpchQuery.java 
new file mode 100644 index 000000000..4def49c74 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/TestExtendedTpchQuery.java @@ -0,0 +1,67 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.Plan; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Additional queries based around the style and schema of the TPC-H set, validating that + * conversions can operate without exceptions. + */ +@TestMethodOrder(OrderAnnotation.class) +@TestInstance(Lifecycle.PER_CLASS) +public class TestExtendedTpchQuery extends PlanTestBase { + + private Map allPlans = new HashMap<>(); + + // Keep list of the known test failures + // The `fromSubstrait` also assumes the to substrait worked as well + public static final List toSubstraitKnownFails = List.of(); + public static final List fromSubstraitKnownFails = List.of(); + + @ParameterizedTest + @Order(1) + @ValueSource(ints = {1}) + public void extendedTpchToSubstrait(int query) throws Exception { + assumeFalse(toSubstraitKnownFails.contains(query)); + + SqlToSubstrait s = new SqlToSubstrait(); + String[] values = asString("tpch/schema.sql").split(";"); + var creates = + Arrays.stream(values) + .filter(t -> !t.trim().isBlank()) + .collect(java.util.stream.Collectors.toList()); + Plan protoPlan = s.execute(asString(String.format("tpch/extended/%02d.sql", query)), creates); + + allPlans.put(query,
protoPlan); + } + + @ParameterizedTest + @Order(2) + @ValueSource(ints = {1}) + public void extendedTpchFromSubstrait(int query) throws Exception { + assumeFalse(fromSubstraitKnownFails.contains(query)); + assumeTrue(allPlans.containsKey(query)); + + Plan possible = allPlans.get(query); + + io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible); + SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true); + System.out.println(SubstraitToSql.toSql(relRoot)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/TestTpcdsQuery.java b/isthmus/src/test/java/io/substrait/isthmus/TestTpcdsQuery.java new file mode 100644 index 000000000..13a91c1d8 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/TestTpcdsQuery.java @@ -0,0 +1,88 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.Plan; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.adapter.tpcds.TpcdsSchema; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Updated TPC-DS test to convert SQL to Substrait and replay those plans back to SQL, validating + * that the conversions can operate without exceptions. + */ +@TestMethodOrder(OrderAnnotation.class) +@TestInstance(Lifecycle.PER_CLASS) +public class TestTpcdsQuery extends PlanTestBase { + + private
List> allPlans; + + @BeforeAll + public void setup() { + allPlans = new ArrayList>(); + for (int i = 1; i < 101; i++) { + allPlans.add(Optional.empty()); + } + } + + // Keep list of the known test failures + // The `fromSubstrait` also assumes the to substrait worked as well + public static final List toSubstraitKnownFails = + List.of(5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, 84, 86, 89, 91, 98); + public static final List fromSubstraitKnownFails = List.of(1, 8, 30, 49, 67, 81); + + @ParameterizedTest + @Order(1) + @ValueSource( + ints = { + 1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30, + 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58, + 59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87, + 88, 90, 92, 93, 94, 95, 96, 97, 99, 2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, + 84, 86, 89, 91, 98, + }) + public void tpcdsSuccess(int query) throws Exception { + assumeFalse(toSubstraitKnownFails.contains(query)); + + SqlToSubstrait s = new SqlToSubstrait(); + TpcdsSchema schema = new TpcdsSchema(1.0); + String sql = asString(String.format("tpcds/queries/%02d.sql", query)); + Plan protoPlan = s.execute(sql, "tpcds", schema); + allPlans.set(query, Optional.of(protoPlan)); + } + + @ParameterizedTest + @Order(2) + @ValueSource( + ints = { + 1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30, + 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58, + 59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87, + 88, 90, 92, 93, 94, 95, 96, 97, 99, 2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, + 84, 86, 89, 91, 98, + }) + public void tpcdsFromSubstrait(int query) throws Exception { + + assumeFalse(fromSubstraitKnownFails.contains(query)); + assumeTrue(allPlans.get(query).isPresent()); + + Optional possible = 
allPlans.get(query); + + io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible.get()); + SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true); + System.out.println(SubstraitToSql.toSql(relRoot)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/TestTpchQuery.java b/isthmus/src/test/java/io/substrait/isthmus/TestTpchQuery.java new file mode 100644 index 000000000..89177d022 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/TestTpchQuery.java @@ -0,0 +1,69 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.Plan; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.RelNode; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Updated TPC-H test to convert SQL to Substrait and replay those plans back to SQL, validating + * that the conversions can operate without exceptions. + */ +@TestMethodOrder(OrderAnnotation.class) +@TestInstance(Lifecycle.PER_CLASS) +public class TestTpchQuery extends PlanTestBase { + + private Map allPlans = new HashMap<>(); + + // Keep list of the known test failures + // The `fromSubstrait` also assumes the to substrait worked as well + public static final List toSubstraitKnownFails = List.of(22); + public static final List fromSubstraitKnownFails = List.of(7, 8, 9); + + @ParameterizedTest + @Order(1) + @ValueSource( + ints =
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22}) + public void tpchToSubstrait(int query) throws Exception { + assumeFalse(toSubstraitKnownFails.contains(query)); + + SqlToSubstrait s = new SqlToSubstrait(); + String[] values = asString("tpch/schema.sql").split(";"); + var creates = + Arrays.stream(values) + .filter(t -> !t.trim().isBlank()) + .collect(java.util.stream.Collectors.toList()); + Plan protoPlan = s.execute(asString(String.format("tpch/queries/%02d.sql", query)), creates); + + allPlans.put(query, protoPlan); + } + + @ParameterizedTest + @Order(2) + @ValueSource( + ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22}) + public void tpchFromSubstrait(int query) throws Exception { + assumeFalse(fromSubstraitKnownFails.contains(query)); + assumeTrue(allPlans.containsKey(query)); + + Plan possible = allPlans.get(query); + + io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible); + SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true); + System.out.println(SubstraitToSql.toSql(relRoot)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryNoValidation.java b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryNoValidation.java index a53fda3d0..f60b43097 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryNoValidation.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryNoValidation.java @@ -2,9 +2,12 @@ import com.google.protobuf.util.JsonFormat; import java.util.Arrays; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.TestMethodOrder; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +@TestMethodOrder(OrderAnnotation.class) public class TpchQueryNoValidation extends PlanTestBase { @ParameterizedTest @@ -18,7 +21,7 @@ public 
void tpch(int query) throws Exception { Arrays.stream(values) .filter(t -> !t.trim().isBlank()) .collect(java.util.stream.Collectors.toList()); - var plan = s.execute(asString(String.format("tpch/queries/%02d.sql", query)), creates); - System.out.println(JsonFormat.printer().print(plan)); + var protoPlan = s.execute(asString(String.format("tpch/queries/%02d.sql", query)), creates); + System.out.println(JsonFormat.printer().print(protoPlan)); } } diff --git a/isthmus/src/test/resources/tpch/extended/01.sql b/isthmus/src/test/resources/tpch/extended/01.sql new file mode 100644 index 000000000..7564930cf --- /dev/null +++ b/isthmus/src/test/resources/tpch/extended/01.sql @@ -0,0 +1,27 @@ +select + c1.c_name, + o1.o_orderstatus, + o1.o_totalprice +from + customer c1, + orders o1 +where + o1.o_custkey = c1.c_custkey + and o1.o_totalprice > ( + select + avg(o_totalprice) + from + orders o2 + where + o2.o_totalprice < c1.c_acctbal + and o2.o_orderpriority = c1.c_phone + and o2.o_totalprice > ( + select + avg(c3.c_acctbal) + from + customer c3 + where + c1.c_custkey = o2.o_custkey + and c3.c_address = o2.o_clerk + ) + ); \ No newline at end of file