From ae6a0e68873d8da298905356a3f5cb87cb3466b1 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 27 May 2025 11:39:32 -0500 Subject: [PATCH] Add conversion of tree rules to a ROBDD Tree-based endpoint rules can be converted to a reduced ordered binary decision diagram, or ROBDD, allowing for a much more compact representation of endpoints and more optimal evaluation with no duplicated conditions for any given path. --- smithy-rules-engine/build.gradle.kts | 2 + .../rulesengine/analysis/HashConsGraph.java | 279 ++++++++++++ .../rulesengine/language/EndpointRuleSet.java | 96 +++- .../language/syntax/bdd/RulesBdd.java | 425 ++++++++++++++++++ .../syntax/bdd/RulesBddCondition.java | 168 +++++++ .../language/syntax/bdd/RulesBddNode.java | 81 ++++ .../language/syntax/rule/EndpointRule.java | 2 +- .../analysis/HashConsGraphTest.java | 70 +++ 8 files changed, 1107 insertions(+), 16 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/HashConsGraph.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBdd.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddCondition.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddNode.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/HashConsGraphTest.java diff --git a/smithy-rules-engine/build.gradle.kts b/smithy-rules-engine/build.gradle.kts index 8bf1bdd248d..ef0e0ba385e 100644 --- a/smithy-rules-engine/build.gradle.kts +++ b/smithy-rules-engine/build.gradle.kts @@ -15,4 +15,6 @@ dependencies { api(project(":smithy-model")) api(project(":smithy-utils")) api(project(":smithy-jmespath")) + + testImplementation(project(":smithy-aws-endpoints")) } diff --git 
a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/HashConsGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/HashConsGraph.java new file mode 100644 index 00000000000..566ca82c28a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/analysis/HashConsGraph.java @@ -0,0 +1,279 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rulesengine.analysis; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.bdd.RulesBdd; +import software.amazon.smithy.rulesengine.language.syntax.bdd.RulesBddCondition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Converts a {@link EndpointRuleSet} into a list of unique paths, a tree of conditions and leaves, and a BDD. + */ +public final class HashConsGraph { + + // Endpoint ruleset to optimize. + private final EndpointRuleSet ruleSet; + + // Provides a hash of endpoints/errors to their index. + private final Map resultHashCons = new HashMap<>(); + + // Provides a hash of conditions to their index. 
+ private final Map conditionHashCons = new HashMap<>(); + + // Provides a mapping of originally defined conditions to their canonicalized conditions. + // (e.g., moving variables before literals in commutative functions). + private final Map canonicalizedConditions = new HashMap<>(); + + // A flattened list of unique leaves. + private final List results = new ArrayList<>(); + + // A flattened list of unique conditions + private final List conditions = new ArrayList<>(); + + // A flattened set of unique condition paths to leaves, sorted based on desired complexity order. + private final Set paths = new LinkedHashSet<>(); + + public HashConsGraph(EndpointRuleSet ruleSet) { + this.ruleSet = ruleSet; + hashConsConditions(); + + // Now build up paths and refer to the hash-consed conditions. + for (Rule rule : ruleSet.getRules()) { + crawlRules(rule, new LinkedHashSet<>()); + } + } + + // First create a global ordering of conditions. The ordering of conditions is the primary way to influence + // the resulting node tables of a BDD. + // 1. Simplest conditions come first (e.g., isSet, booleanEquals, etc.). We build this up by gathering all + // the stateless conditions and sorting them by complexity order so that simplest checks happen earlier. + // 2. Stateful conditions come after, and they must appear in a dependency ordering (i.e., if a condition + // depends on a previous condition to bind a variable, then it must come after its dependency). This is + // done by iterating over paths and adding stateful conditions, in path order, to a LinkedHashSet of + // conditions, giving us a hash-consed but ordered set of all stateful conditions across all paths. 
+ private void hashConsConditions() { + Set statelessCondition = new LinkedHashSet<>(); + Set statefulConditions = new LinkedHashSet<>(); + for (Rule rule : ruleSet.getRules()) { + crawlConditions(rule, statelessCondition, statefulConditions); + } + + // Sort the stateless conditions by complexity order, maintaining insertion order when equal. + List sortedStatelessConditions = new ArrayList<>(statelessCondition); + sortedStatelessConditions.sort(Comparator.comparingInt(RulesBddCondition::getComplexity)); + + // Now build up the hash-consed map of conditions to their integer position in a sorted array of RuleCondition. + hashConsCollectedConditions(sortedStatelessConditions); + hashConsCollectedConditions(statefulConditions); + } + + private void hashConsCollectedConditions(Collection ruleConditions) { + for (RulesBddCondition ruleCondition : ruleConditions) { + conditionHashCons.put(ruleCondition.getCondition(), conditions.size()); + conditions.add(ruleCondition); + } + } + + public List getPaths() { + return new ArrayList<>(paths); + } + + public List getConditions() { + return new ArrayList<>(conditions); + } + + public List getResults() { + return new ArrayList<>(results); + } + + public EndpointRuleSet getRuleSet() { + return ruleSet; + } + + public RulesBdd getBdd() { + return RulesBdd.from(this); + } + + // Crawl rules to build up the global total ordering of variables. + private void crawlConditions( + Rule rule, + Set statelessConditions, + Set statefulConditions + ) { + for (Condition condition : rule.getConditions()) { + if (!canonicalizedConditions.containsKey(condition)) { + // Create the RuleCondition and also canonicalize the underlying condition. + RulesBddCondition ruleCondition = RulesBddCondition.from(condition, ruleSet); + // Add a mapping between the original condition and the canonicalized condition. 
+ canonicalizedConditions.put(condition, ruleCondition.getCondition()); + if (ruleCondition.isStateful()) { + statefulConditions.add(ruleCondition); + } else { + statelessConditions.add(ruleCondition); + } + } + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule subRule : treeRule.getRules()) { + crawlConditions(subRule, statelessConditions, statefulConditions); + } + } + } + + private void crawlRules(Rule rule, Set conditionIndices) { + for (Condition condition : rule.getConditions()) { + Condition c = Objects.requireNonNull(canonicalizedConditions.get(condition), "Condition not found"); + Integer idx = Objects.requireNonNull(conditionHashCons.get(c), "Condition not hashed"); + conditionIndices.add(idx); + } + + Rule leaf = null; + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule subRule : treeRule.getRules()) { + crawlRules(subRule, new LinkedHashSet<>(conditionIndices)); + } + } else if (!rule.getConditions().isEmpty()) { + leaf = createStandaloneResult(rule); + } else { + leaf = rule; + } + + if (leaf != null) { + int position = resultHashCons.computeIfAbsent(leaf, l -> { + results.add(l); + return results.size() - 1; + }); + paths.add(createPath(position, conditionIndices)); + } + } + + // Create a rule that strips off conditions and is just left with docs + the error or endpoint. 
+ private static Rule createStandaloneResult(Rule rule) { + if (rule instanceof ErrorRule) { + ErrorRule e = (ErrorRule) rule; + return new ErrorRule( + ErrorRule.builder().description(e.getDocumentation().orElse(null)), + e.getError()); + } else if (rule instanceof EndpointRule) { + EndpointRule e = (EndpointRule) rule; + return new EndpointRule( + EndpointRule.builder().description(e.getDocumentation().orElse(null)), + e.getEndpoint()); + } else { + throw new UnsupportedOperationException("Unsupported result node: " + rule); + } + } + + private BddPath createPath(int leafIdx, Set conditionIndices) { + Set statefulConditions = new LinkedHashSet<>(); + Set statelessConditions = new TreeSet<>((a, b) -> { + int conditionComparison = ruleComparator(conditions.get(a), conditions.get(b)); + // fall back to index comparison to ensure uniqueness + return conditionComparison != 0 ? conditionComparison : Integer.compare(a, b); + }); + + for (Integer conditionIdx : conditionIndices) { + RulesBddCondition node = conditions.get(conditionIdx); + if (!node.isStateful()) { + statelessConditions.add(conditionIdx); + } else { + statefulConditions.add(conditionIdx); + } + } + + return new BddPath(leafIdx, statelessConditions, statefulConditions); + } + + private int ruleComparator(RulesBddCondition a, RulesBddCondition b) { + return Integer.compare(a.getComplexity(), b.getComplexity()); + } + + /** + * Represents a path through rule conditions to reach a specific result. + * + *

Contains both stateless conditions (sorted by complexity) and stateful conditions (ordered by dependency) + * that must be evaluated to reach the target leaf (endpoint or error). + */ + public static final class BddPath { + + // The endpoint or error index. + private final int leafIndex; + + // Conditions that create or use stateful bound variables and must be maintained in order. + private final Set statefulConditions; + + // Sort conditions based on complexity scores. + private final Set statelessConditions; + + private int hash; + + BddPath(int leafIndex, Set statelessConditions, Set statefulConditions) { + this.leafIndex = leafIndex; + this.statelessConditions = Collections.unmodifiableSet(statelessConditions); + this.statefulConditions = Collections.unmodifiableSet(statefulConditions); + } + + public Set getStatefulConditions() { + return statefulConditions; + } + + public Set getStatelessConditions() { + return statelessConditions; + } + + public int getLeafIndex() { + return leafIndex; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } + BddPath path = (BddPath) object; + return leafIndex == path.leafIndex + && statefulConditions.equals(path.statefulConditions) + && statelessConditions.equals(path.statelessConditions); + } + + @Override + public int hashCode() { + int result = hash; + if (result == 0) { + result = Objects.hash(leafIndex, statefulConditions, statelessConditions); + hash = result; + } + return result; + } + + @Override + public String toString() { + return "Path{statelessConditions=" + statelessConditions + ", statefulConditions=" + statefulConditions + + ", leafIndex=" + leafIndex + '}'; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java 
index 9f25f9ce426..77e5bdb36d0 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java @@ -21,10 +21,12 @@ import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.node.StringNode; import software.amazon.smithy.model.node.ToNode; +import software.amazon.smithy.rulesengine.analysis.HashConsGraph; import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.TypeCheck; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.bdd.RulesBdd; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; @@ -48,6 +50,7 @@ public final class EndpointRuleSet implements FromSourceLocation, ToNode, ToSmit private static final String VERSION = "version"; private static final String PARAMETERS = "parameters"; private static final String RULES = "rules"; + private static final String BDD = "bdd"; private static final class LazyEndpointComponentFactoryHolder { static final EndpointComponentFactory INSTANCE = EndpointComponentFactory.createServiceFactory( @@ -58,6 +61,7 @@ private static final class LazyEndpointComponentFactoryHolder { private final List rules; private final SourceLocation sourceLocation; private final String version; + private final RulesBdd bdd; private EndpointRuleSet(Builder builder) { super(); @@ -65,6 +69,7 @@ private EndpointRuleSet(Builder builder) { rules = builder.rules.copy(); sourceLocation = SmithyBuilder.requiredState("source", 
builder.getSourceLocation()); version = SmithyBuilder.requiredState(VERSION, builder.version); + bdd = builder.bdd; } /** @@ -90,14 +95,38 @@ public static EndpointRuleSet fromNode(Node node) throws RuleError { builder.parameters(Parameters.fromNode(objectNode.expectObjectMember(PARAMETERS))); objectNode.expectStringMember(VERSION, builder::version); - for (Node element : objectNode.expectArrayMember(RULES).getElements()) { - builder.addRule(context("while parsing rule", element, () -> EndpointRule.fromNode(element))); - } + objectNode.getArrayMember(RULES).ifPresent(rules -> { + for (Node element : rules) { + builder.addRule(context("while parsing rule", element, () -> EndpointRule.fromNode(element))); + } + }); + + objectNode.getObjectMember(BDD).ifPresent(o -> { + builder.bdd(RulesBdd.fromNode(o)); + }); return builder.build(); }); } + /** + * Convert the endpoint ruleset to BDD form if it isn't already. + * + * @return the current ruleset if in BDD form, otherwise a new ruleset instance in BDD form. + */ + public EndpointRuleSet toBddForm() { + if (bdd != null) { + return this; + } + + return builder() + .version(version) + .parameters(parameters) + .bdd(new HashConsGraph(this).getBdd()) + .sourceLocation(sourceLocation) + .build(); + } + @Override public SourceLocation getSourceLocation() { return sourceLocation; @@ -130,6 +159,15 @@ public String getVersion() { return version; } + /** + * Get the BDD for this rule-set, if it exists. + * + * @return the optionally present BDD definitions. 
+ */ + public Optional getBdd() { + return Optional.ofNullable(bdd); + } + public Type typeCheck() { return typeCheck(new Scope<>()); } @@ -151,19 +189,27 @@ public Builder toBuilder() { .sourceLocation(getSourceLocation()) .parameters(parameters) .rules(rules) - .version(version); + .version(version) + .bdd(bdd); } @Override public Node toNode() { - ArrayNode.Builder rulesBuilder = ArrayNode.builder(); - rules.forEach(rulesBuilder::withValue); - - return ObjectNode.builder() + ObjectNode.Builder builder = ObjectNode.builder() .withMember(VERSION, version) - .withMember(PARAMETERS, parameters) - .withMember(RULES, rulesBuilder.build()) - .build(); + .withMember(PARAMETERS, parameters); + + if (!rules.isEmpty()) { + ArrayNode.Builder rulesBuilder = ArrayNode.builder(); + rules.forEach(rulesBuilder::withValue); + builder.withMember(RULES, rulesBuilder.build()); + } + + if (bdd != null) { + builder.withMember(BDD, bdd.toNode()); + } + + return builder.build(); } @Override @@ -175,12 +221,15 @@ public boolean equals(Object o) { return false; } EndpointRuleSet that = (EndpointRuleSet) o; - return rules.equals(that.rules) && parameters.equals(that.parameters) && version.equals(that.version); + return rules.equals(that.rules) + && parameters.equals(that.parameters) + && version.equals(that.version) + && Objects.equals(bdd, that.bdd); } @Override public int hashCode() { - return Objects.hash(rules, parameters, version); + return Objects.hash(rules, parameters, version, bdd); } @Override @@ -188,8 +237,13 @@ public String toString() { StringBuilder builder = new StringBuilder(); builder.append(String.format("version: %s%n", version)); builder.append("params: \n").append(StringUtils.indent(parameters.toString(), 2)); - builder.append("rules: \n"); - rules.forEach(rule -> builder.append(StringUtils.indent(rule.toString(), 2))); + if (!rules.isEmpty()) { + builder.append("rules: \n"); + rules.forEach(rule -> builder.append(StringUtils.indent(rule.toString(), 2))); + } + if 
(bdd != null) { + builder.append("bdd: \n").append(bdd); + } return builder.toString(); } @@ -242,6 +296,7 @@ public static class Builder extends RulesComponentBuilder conditions; + private final List results; + private final int[][] nodes; + private final int root; + + public RulesBdd(List conditions, List results, int[][] nodes, int root) { + this.conditions = conditions; + this.results = results; + this.nodes = nodes; + this.root = root; + + if (root < 0) { + int adjustedRoot = -(root + 1); + if (adjustedRoot >= results.size()) { + throw new IllegalArgumentException("Root node references out of bounds result: " + + adjustedRoot + " vs " + results.size()); + } + } else if (root >= nodes.length) { + throw new IllegalArgumentException("Root node references out of bounds node: " + root); + } + } + + /** + * Get the root node of the diagram. + * + *

The root node may be less than zero, indicating that there are no nodes, and a result is returned + * directly and unconditionally. + * + * @return the root node. + */ + public int getRootNode() { + return root; + } + + /** + * Get the list of conditions. + * + * @return conditions. + */ + public List getConditions() { + return conditions; + } + + /** + * Get the nodes in the BDD. + * + * @return BDD nodes. + */ + public int[][] getNodes() { + return nodes; + } + + /** + * Get a node by index. + * + * @param idx Node index. + * @return the node at this index. + */ + public int[] getNode(int idx) { + return nodes[idx]; + } + + /** + * Get a result by index (an ErrorRule or EndpointRule). + * + * @param idx Index of the result. + * @return the result at this index. + */ + public Rule getResult(int idx) { + // Account for how results are encoded as negative numbers in rule triples (e.g., -1 becomes result 0). + if (idx < 0) { + idx = decodeLeafReference(idx); + } + return results.get(idx); + } + + /** + * Get the rule results in the BDD (error and endpoint rules with no conditions). + * + * @return BDD rule results. 
+ */ + public List getResults() { + return results; + } + + @Override + public String toString() { + StringBuilder s = new StringBuilder(); + s.append("RulesBdd{\n"); + s.append(" conditions:"); + for (int i = 0; i < conditions.size(); i++) { + s.append("\n ").append(i).append(": ").append(formatString(conditions.get(i))); + } + s.append("\n results:"); + for (int i = 0; i < results.size(); i++) { + s.append("\n ").append(i).append(": ").append(formatString(results.get(i))); + } + s.append("\n root: ").append(root); + s.append("\n nodes:"); + for (int i = 0; i < nodes.length; i++) { + s.append("\n ").append(i).append(": ").append(Arrays.toString(nodes[i])); + } + s.append("\n}"); + return s.toString(); + } + + private static String formatString(Object o) { + String s = o.toString(); + if (s.contains("\n")) { + s = s.replace("\n", "\n "); + } + return s.trim(); + } + + /** + * Create a BDD object from a Node. + * + * @param node Node to deserialize. + * @return the created BDD. + */ + public static RulesBdd fromNode(Node node) { + return RulesBddNode.fromNode(node); + } + + @Override + public Node toNode() { + return RulesBddNode.toNode(this); + } + + // Encodes a leaf index as a negative number for BDD storage + // Leaf indices are stored as -(index + LEAF_OFFSET) to distinguish from node indices. + private static int encodeLeafReference(int leafIndex) { + return -(leafIndex + LEAF_OFFSET); + } + + private static int decodeLeafReference(int encodedRef) { + return -encodedRef - LEAF_OFFSET; + } + + /** + * Create a BDD from a processed graph of tree-based rules. + * + * @param graph Graph to process. + * @return the BDD result. 
+ */ + public static RulesBdd from(HashConsGraph graph) { + Objects.requireNonNull(graph, "RuleGraph is null"); + List nodesList = new ArrayList<>(); + List results = new ArrayList<>(graph.getResults()); + int fallbackLeafIndex = getFallbackLeafIndex(graph, results); + int root = buildBDDNode(graph.getPaths(), graph.getConditions(), 0, nodesList, fallbackLeafIndex); + + // Extract the list of conditions from wrapped RuleConditions since that's all we need in the BDD class. + List justConditions = new ArrayList<>(graph.getConditions().size()); + for (RulesBddCondition condition : graph.getConditions()) { + justConditions.add(condition.getCondition()); + } + + if (root < 0) { + // No nodes created? Return BDD with empty node array that points to a result. + return new RulesBdd(justConditions, results, new int[0][], root); + } else { + // There are nodes (the norm), so reverse the node order. + int[][] nodes = reverseNodeOrder(nodesList); + int adjustedRoot = nodesList.size() - 1 - root; + return new RulesBdd(justConditions, results, nodes, adjustedRoot); + } + } + + // Find the path, if any, that has no conditions and leads to the top-most fallback error condition. + private static int getFallbackLeafIndex(HashConsGraph graph, List results) { + int fallbackLeafIndex = -1; + + for (HashConsGraph.BddPath path : graph.getPaths()) { + if (path.getStatelessConditions().isEmpty() && path.getStatefulConditions().isEmpty()) { + if (fallbackLeafIndex != -1) { + throw new IllegalStateException("Multiple paths with no conditions"); + } + fallbackLeafIndex = path.getLeafIndex(); + } + } + + // Create a fallback leaf if one wasn't found. + if (fallbackLeafIndex == -1) { + results.add(ErrorRule.builder().error("Unable to resolve an endpoint")); + fallbackLeafIndex = results.size() - 1; + } + + return fallbackLeafIndex; + } + + /** + * Recursively constructs BDD nodes by partitioning rule paths based on condition evaluation. 
+ * + * @return Node index (non-negative) or encoded leaf reference (negative). + */ + private static int buildBDDNode( + List paths, + List conditions, + int conditionIndex, + List nodesList, + int fallbackLeafIndex + ) { + // no paths left, use fallback + if (paths.isEmpty()) { + return encodeLeafReference(fallbackLeafIndex); + } + + // Skip conditions that no remaining path cares about, unless they're stateful. + conditionIndex = findNextRelevantCondition(paths, conditions, conditionIndex); + + // No more conditions to check: resolve to the leaf of the first remaining path + if (conditionIndex >= conditions.size()) { + return encodeLeafReference(paths.get(0).getLeafIndex()); + } + + // Short-circuit here if all paths lead to the same leaf and the current condition is not a stateful + // condition used by any remaining path. + if (canShortCircuit(paths, conditions, conditionIndex)) { + return encodeLeafReference(paths.get(0).getLeafIndex()); + } + + return createDecisionNode(paths, conditions, conditionIndex, nodesList, fallbackLeafIndex); + } + + private static int findNextRelevantCondition( + List paths, + List conditions, + int conditionIndex + ) { + while (conditionIndex < conditions.size()) { + boolean anyPathCaresAboutCondition = false; + for (HashConsGraph.BddPath path : paths) { + if (path.getStatelessConditions().contains(conditionIndex) + || path.getStatefulConditions().contains(conditionIndex)) { + anyPathCaresAboutCondition = true; + break; + } + } + + boolean isStateful = conditions.get(conditionIndex).isStateful(); + + if (!anyPathCaresAboutCondition && !isStateful) { + // Skip this condition - no remaining path cares and it's not stateful + conditionIndex++; + } else { + break; // Found a relevant condition + } + } + return conditionIndex; + } + + private static boolean canShortCircuit(List paths, List conditions, int idx) { + // Short circuit if all paths lead to same leaf and no stateful condition needs processing. 
+ return !hasStatefulConditionAtIndex(paths, conditions, idx) && allSameResultLeaf(paths); + } + + // Detects if the given list of paths all resolve to the same result. + private static boolean allSameResultLeaf(List paths) { + if (!paths.isEmpty()) { + int firstLeaf = paths.get(0).getLeafIndex(); + for (int pos = 1; pos < paths.size(); pos++) { + if (paths.get(pos).getLeafIndex() != firstLeaf) { + return false; + } + } + } + + return true; // Empty list trivially has all same result + } + + // Check if we need to process a stateful condition even if all paths lead to same leaf + private static boolean hasStatefulConditionAtIndex( + List paths, + List conditions, + int idx + ) { + if (idx >= conditions.size()) { + return false; + } else if (!conditions.get(idx).isStateful()) { + return false; + } + + // Check if any path contains this stateful condition + for (HashConsGraph.BddPath path : paths) { + if (path.getStatefulConditions().contains(idx)) { + return true; + } + } + + return false; + } + + private static int createDecisionNode( + List paths, + List conditions, + int conditionIndex, + List nodesList, + int fallbackLeafIndex + ) { + // Split paths based on current condition + List truePaths = new ArrayList<>(); + List falsePaths = new ArrayList<>(); + for (HashConsGraph.BddPath path : paths) { + if (isTruePath(path, conditionIndex)) { + truePaths.add(path); + } else { + falsePaths.add(path); + } + } + + // Special handling for stateful conditions that might be elided due to short-circuiting. + // BDDs optimize by eliminating redundant nodes (where true/false branches are identical), and by + // short-circuiting when all remaining paths lead to the same outcome. However, stateful conditions in our + // rules engine must always execute for side effects and dependency ordering. + // + // When truePaths.isEmpty(), no remaining paths require this condition to be true, meaning a standard BDD + // would skip it entirely. 
We create a pass-through node to ensure execution while maintaining the BDD + // decision structure. + if (truePaths.isEmpty() && conditions.get(conditionIndex).isStateful()) { + return createPassThroughNode(paths, conditions, conditionIndex, nodesList, fallbackLeafIndex); + } + + // Recursively build subtrees + int trueTarget = buildBDDNode(truePaths, conditions, conditionIndex + 1, nodesList, fallbackLeafIndex); + int falseTarget = buildBDDNode(falsePaths, conditions, conditionIndex + 1, nodesList, fallbackLeafIndex); + + // Create this node + int currentNodeIndex = nodesList.size(); + int[] node = {conditionIndex, trueTarget, falseTarget}; + nodesList.add(node); + + return currentNodeIndex; + } + + private static boolean isTruePath(HashConsGraph.BddPath path, int conditionIndex) { + return path.getStatelessConditions().contains(conditionIndex) || + path.getStatefulConditions().contains(conditionIndex); + } + + // Creates a pass-through decision node for stateful conditions that must execute even when no paths require + // them to be true. + private static int createPassThroughNode( + List paths, + List conditions, + int conditionIndex, + List nodesList, + int fallbackLeafIndex + ) { + int target = buildBDDNode(paths, conditions, conditionIndex + 1, nodesList, fallbackLeafIndex); + int currentNodeIndex = nodesList.size(); + int[] node = {conditionIndex, target, target}; // Both true and false go to same target + nodesList.add(node); + return currentNodeIndex; + } + + // --------- Node reversal methods --------- + + // Reverses the order of nodes in the BDD to make it easier to read, and more cache friendly. + // The root node becomes index 0, and all node references are updated accordingly. + private static int[][] reverseNodeOrder(List originalNodes) { + int nodeCount = originalNodes.size(); + int[][] reversedNodes = new int[nodeCount][]; + + // Create mapping from old index to new index. 
+ // - Root node (originally last) becomes index 0 + // - Last node (originally first) becomes index `nodeCount - 1` + int[] indexMapping = new int[nodeCount]; + for (int i = 0; i < nodeCount; i++) { + indexMapping[i] = nodeCount - 1 - i; + } + + // Copy nodes in reverse order and update their references + for (int oldIndex = 0; oldIndex < nodeCount; oldIndex++) { + int newIndex = indexMapping[oldIndex]; + int[] originalNode = originalNodes.get(oldIndex); + + int condition = originalNode[CONDITION_INDEX]; + int originalTrueTarget = originalNode[TRUE_TARGET]; + int originalFalseTarget = originalNode[FALSE_TARGET]; + + // Update node references + int newTrueTarget = updateNodeReference(originalTrueTarget, indexMapping); + int newFalseTarget = updateNodeReference(originalFalseTarget, indexMapping); + + reversedNodes[newIndex] = new int[]{condition, newTrueTarget, newFalseTarget}; + } + + return reversedNodes; + } + + // Updates a node reference when reordering nodes. Leaves are negative and aren't changed. + private static int updateNodeReference(int originalRef, int[] indexMapping) { + return originalRef < 0 ? originalRef : indexMapping[originalRef]; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddCondition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddCondition.java new file mode 100644 index 00000000000..2053a5405a1 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddCondition.java @@ -0,0 +1,168 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package software.amazon.smithy.rulesengine.language.syntax.bdd;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.Template;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal;
+import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral;
+import software.amazon.smithy.rulesengine.language.syntax.rule.Condition;
+
+/**
+ * A canonicalized rule {@link Condition} annotated with a heuristic complexity score and a flag that
+ * tells whether the condition is stateful.
+ *
+ * <p>A condition is stateful when it assigns a result, or when it references a name that is not a
+ * parameter of the rule set it was created for. The complexity score grows by one for every expression
+ * visited, by a per-function weight for every library function, and by a fixed amount for every
+ * non-static string template.
+ */
+public final class RulesBddCondition {
+
+    // Heuristic based complexity score used for functions that have no explicit assignment below.
+    private static final int UNKNOWN_FUNCTION_COMPLEXITY = 4;
+
+    // Extra complexity added for a string literal that contains dynamic template parts.
+    private static final int DYNAMIC_TEMPLATE_COMPLEXITY = 8;
+
+    // Heuristic based complexity score for each known function, keyed by function name.
+    private static final Map<String, Integer> COMPLEXITY_ASSIGNMENTS = new HashMap<>();
+    static {
+        COMPLEXITY_ASSIGNMENTS.put("isSet", 1);
+        COMPLEXITY_ASSIGNMENTS.put("not", 2);
+        COMPLEXITY_ASSIGNMENTS.put("booleanEquals", 3);
+        COMPLEXITY_ASSIGNMENTS.put("stringEquals", 4);
+        COMPLEXITY_ASSIGNMENTS.put("substring", 5);
+        COMPLEXITY_ASSIGNMENTS.put("aws.partition", 6);
+        COMPLEXITY_ASSIGNMENTS.put("getAttr", 7);
+        COMPLEXITY_ASSIGNMENTS.put("uriEncode", 8);
+        COMPLEXITY_ASSIGNMENTS.put("aws.parseArn", 9);
+        COMPLEXITY_ASSIGNMENTS.put("isValidHostLabel", 10);
+        COMPLEXITY_ASSIGNMENTS.put("parseURL", 11);
+    }
+
+    private final Condition condition;
+    // Both values are accumulated while the condition is crawled in the constructor.
+    private boolean isStateful;
+    private int complexity = 0;
+    // Lazily computed hash code; zero means "not computed yet".
+    private int hash = 0;
+
+    private RulesBddCondition(Condition condition, EndpointRuleSet ruleSet) {
+        this.condition = condition;
+        // Conditions that assign a value are always considered stateful.
+        isStateful = condition.getResult().isPresent();
+        crawlCondition(ruleSet, 0, condition.getFunction());
+    }
+
+    /**
+     * Creates a BDD condition from a rule condition.
+     *
+     * <p>The condition is canonicalized first so that {@code booleanEquals} and {@code stringEquals}
+     * calls that only differ in argument order map to the same BDD condition.
+     *
+     * @param condition Condition to wrap.
+     * @param ruleSet Rule set whose parameters decide whether a reference is stateful.
+     * @return the created BDD condition.
+     */
+    public static RulesBddCondition from(Condition condition, EndpointRuleSet ruleSet) {
+        return new RulesBddCondition(canonicalizeCondition(condition), ruleSet);
+    }
+
+    // Canonicalize conditions such that variable references for booleanEquals and stringEquals come before
+    // a literal. This ensures that these commutative functions count as a single variable and don't needlessly
+    // bloat the BDD table.
+    private static Condition canonicalizeCondition(Condition condition) {
+        Expression func = condition.getFunction();
+        if (func instanceof BooleanEquals) {
+            BooleanEquals f = (BooleanEquals) func;
+            if (hasLeadingLiteral(f)) {
+                // Flip the order to move the literal last.
+                return condition.toBuilder()
+                        .fn(BooleanEquals.ofExpressions(f.getArguments().get(1), f.getArguments().get(0)))
+                        .build();
+            }
+        } else if (func instanceof StringEquals) {
+            StringEquals f = (StringEquals) func;
+            if (hasLeadingLiteral(f)) {
+                // Flip the order to move the literal last.
+                return condition.toBuilder()
+                        .fn(StringEquals.ofExpressions(f.getArguments().get(1), f.getArguments().get(0)))
+                        .build();
+            }
+        }
+
+        return condition;
+    }
+
+    // True when the first argument is a literal and the second is not, i.e. the arguments should be swapped.
+    private static boolean hasLeadingLiteral(LibraryFunction f) {
+        return f.getArguments().get(0) instanceof Literal && !(f.getArguments().get(1) instanceof Literal);
+    }
+
+    /**
+     * Gets the canonicalized condition.
+     *
+     * @return the wrapped condition.
+     */
+    public Condition getCondition() {
+        return condition;
+    }
+
+    /**
+     * Gets the heuristic complexity score of the condition; larger values mean more nested or
+     * more heavily weighted expressions.
+     *
+     * @return the complexity score.
+     */
+    public int getComplexity() {
+        return complexity;
+    }
+
+    /**
+     * Checks if the condition assigns a result or references a name that is not a rule set parameter.
+     *
+     * @return true if the condition is stateful.
+     */
+    public boolean isStateful() {
+        return isStateful;
+    }
+
+    // Visits an expression and everything nested in it. The depth value is threaded through the walk
+    // but is not currently read by any of the visitors.
+    private void crawlCondition(EndpointRuleSet ruleSet, int depth, Expression e) {
+        // Every level of nesting is an automatic complexity++.
+        complexity++;
+        if (e instanceof Literal) {
+            walkLiteral(ruleSet, (Literal) e, depth);
+        } else if (e instanceof Reference) {
+            walkReference(ruleSet, (Reference) e);
+        } else if (e instanceof LibraryFunction) {
+            walkLibraryFunction(ruleSet, (LibraryFunction) e, depth);
+        }
+    }
+
+    // Only string literals with dynamic template parts contribute anything beyond the base increment.
+    private void walkLiteral(EndpointRuleSet ruleSet, Literal l, int depth) {
+        if (!(l instanceof StringLiteral)) {
+            return;
+        }
+
+        Template template = ((StringLiteral) l).value();
+        if (template.isStatic()) {
+            return;
+        }
+
+        complexity += DYNAMIC_TEMPLATE_COMPLEXITY;
+        for (Template.Part part : template.getParts()) {
+            if (part instanceof Template.Dynamic) {
+                // Need to check for dynamic variables that reference non-global params.
+                // Also add to the score for each parameter.
+                Template.Dynamic dynamic = (Template.Dynamic) part;
+                crawlCondition(ruleSet, depth + 1, dynamic.toExpression());
+            }
+        }
+    }
+
+    private void walkReference(EndpointRuleSet ruleSet, Reference r) {
+        // It's stateful if the name referenced here is not an input parameter name.
+        if (!ruleSet.getParameters().get(r.getName()).isPresent()) {
+            isStateful = true;
+        }
+    }
+
+    private void walkLibraryFunction(EndpointRuleSet ruleSet, LibraryFunction l, int depth) {
+        // Track function complexity, falling back to a default weight for functions that are not in the table.
+        Integer functionComplexity = COMPLEXITY_ASSIGNMENTS.get(l.getName());
+        complexity += functionComplexity != null ? functionComplexity : UNKNOWN_FUNCTION_COMPLEXITY;
+        // Crawl the arguments.
+        for (Expression arg : l.getArguments()) {
+            crawlCondition(ruleSet, depth + 1, arg);
+        }
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) {
+            return true;
+        } else if (object == null || getClass() != object.getClass()) {
+            return false;
+        } else {
+            RulesBddCondition that = (RulesBddCondition) object;
+            return isStateful == that.isStateful
+                    && complexity == that.complexity
+                    && Objects.equals(condition, that.condition);
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        // Read the cached value once; a computed hash of zero is simply recomputed on the next call.
+        int result = hash;
+        if (result == 0) {
+            result = Objects.hash(condition, isStateful, complexity);
+            hash = result;
+        }
+        return result;
+    }
+}
diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddNode.java
new file mode 100644
index 00000000000..2d63474d967
--- /dev/null
+++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/bdd/RulesBddNode.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package software.amazon.smithy.rulesengine.language.syntax.bdd;
+
+import java.util.ArrayList;
+import java.util.List;
+import software.amazon.smithy.model.node.ArrayNode;
+import software.amazon.smithy.model.node.ExpectationNotMetException;
+import software.amazon.smithy.model.node.Node;
+import software.amazon.smithy.model.node.ObjectNode;
+import software.amazon.smithy.rulesengine.language.syntax.rule.Condition;
+import software.amazon.smithy.rulesengine.language.syntax.rule.Rule;
+
+/**
+ * Handles converting BDD to and from Nodes.
+ *
+ * <p>The serialized form is an object with a numeric {@code root}, a {@code conditions} array, a
+ * {@code results} array, and a {@code nodes} array in which every entry is an array of three numbers.
+ */
+final class RulesBddNode {
+
+    // Number of integers that make up a single serialized BDD node.
+    private static final int NODE_WIDTH = 3;
+
+    private RulesBddNode() {}
+
+    /**
+     * Serializes a BDD into a node.
+     *
+     * <p>The {@code nodes} member is always written, even when empty, because {@link #fromNode(Node)}
+     * requires it to be present.
+     *
+     * @param bdd BDD to serialize.
+     * @return the object node containing the root, conditions, results, and nodes of the BDD.
+     */
+    static Node toNode(RulesBdd bdd) {
+        ObjectNode.Builder builder = Node.objectNodeBuilder();
+        builder.withMember("root", bdd.getRootNode());
+
+        ArrayNode.Builder conditionBuilder = ArrayNode.builder();
+        for (Condition condition : bdd.getConditions()) {
+            conditionBuilder.withValue(condition.toNode());
+        }
+        builder.withMember("conditions", conditionBuilder.build());
+
+        ArrayNode.Builder resultBuilder = ArrayNode.builder();
+        for (Rule result : bdd.getResults()) {
+            resultBuilder.withValue(result.toNode());
+        }
+        builder.withMember("results", resultBuilder.build());
+
+        // Each row of the node table becomes its own array of numbers.
+        // NOTE(review): assumes getNodes() returns the same int[][] table that fromNode passes to the
+        // RulesBdd constructor -- confirm against RulesBdd.
+        ArrayNode.Builder nodeBuilder = ArrayNode.builder();
+        for (int[] row : bdd.getNodes()) {
+            ArrayNode.Builder rowBuilder = ArrayNode.builder();
+            for (int value : row) {
+                rowBuilder.withValue(Node.from(value));
+            }
+            nodeBuilder.withValue(rowBuilder.build());
+        }
+        builder.withMember("nodes", nodeBuilder.build());
+
+        // Return the populated object rather than a fresh, empty one.
+        return builder.build();
+    }
+
+    /**
+     * Deserializes a BDD from a node.
+     *
+     * @param node Node to deserialize; must be an object with {@code root}, {@code conditions},
+     *             {@code results}, and {@code nodes} members.
+     * @return the deserialized BDD.
+     * @throws ExpectationNotMetException if a member is missing or has the wrong type, or if an entry of
+     *         {@code nodes} does not contain exactly three numbers.
+     */
+    static RulesBdd fromNode(Node node) {
+        ObjectNode o = node.expectObjectNode();
+        int root = o.expectNumberMember("root").getValue().intValue();
+
+        ArrayNode conditionsArray = o.expectArrayMember("conditions");
+        List<Condition> conditions = new ArrayList<>(conditionsArray.size());
+        for (Node value : conditionsArray.getElements()) {
+            conditions.add(Condition.fromNode(value));
+        }
+
+        ArrayNode resultsArray = o.expectArrayMember("results");
+        List<Rule> results = new ArrayList<>(resultsArray.size());
+        for (Node value : resultsArray.getElements()) {
+            results.add(Rule.fromNode(value));
+        }
+
+        ArrayNode nodesArray = o.expectArrayMember("nodes");
+        int[][] nodes = new int[nodesArray.size()][];
+        int row = 0;
+        for (Node value : nodesArray.getElements()) {
+            ArrayNode nodeArray = value.expectArrayNode();
+            if (nodeArray.size() != NODE_WIDTH) {
+                throw new ExpectationNotMetException("Each node array must have three numbers", nodeArray);
+            }
+            int[] nodeRow = new int[NODE_WIDTH];
+            for (int i = 0; i < NODE_WIDTH; i++) {
+                // The size check above guarantees that every index is present.
+                nodeRow[i] = nodeArray.get(i).get().expectNumberNode().getValue().intValue();
+            }
+            nodes[row++] = nodeRow;
+        }
+
+        return new RulesBdd(conditions, results, nodes, root);
+    }
+}
diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java
index f20c483fdf0..5aa266590a1 100644
--- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java
+++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java
@@ -21,7 +21,7 @@ public final class EndpointRule extends Rule {
 
     private final Endpoint endpoint;
 
-    EndpointRule(Rule.Builder builder, Endpoint endpoint) {
+    public EndpointRule(Rule.Builder builder, Endpoint endpoint) {
         super(builder);
         this.endpoint = endpoint;
     }
diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/HashConsGraphTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/HashConsGraphTest.java
new file mode 100644
index 00000000000..c230af6c4f4
--- /dev/null
+++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/analysis/HashConsGraphTest.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package software.amazon.smithy.rulesengine.analysis;
+
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.Test;
+import software.amazon.smithy.model.Model;
+import software.amazon.smithy.model.loader.ModelAssembler;
+import software.amazon.smithy.model.node.Node;
+import software.amazon.smithy.model.shapes.ModelSerializer;
+import software.amazon.smithy.model.shapes.ServiceShape;
+import software.amazon.smithy.model.shapes.ShapeId;
+import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
+import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
+
+/**
+ * Exploratory test that converts the endpoint rule set of a locally available service model into a
+ * {@link HashConsGraph}, prints statistics about it, and writes the model with the BDD form of the
+ * rule set to a temporary file.
+ *
+ * <p>The models are read from machine-specific locations, so the test is skipped when the selected
+ * model file does not exist.
+ */
+public class HashConsGraphTest {
+
+    // Each entry is { path to a JSON model on the local machine, shape ID of the service to convert }.
+    private static final String[] REGIONAL = {
+            "/Users/dowling/projects/aws-sdk-js-v3/codegen/sdk-codegen/aws-models/connect.json",
+            "com.amazonaws.connect#AmazonConnectService"
+    };
+    private static final String[] S3 = {
+            "/Users/dowling/projects/smithy-java/aws/client/aws-client-rulesengine/src/shared-resources/software/amazon/smithy/java/aws/client/rulesengine/s3.json",
+            "com.amazonaws.s3#AmazonS3"
+    };
+
+    @Test
+    public void test() throws Exception {
+        // Switch to REGIONAL to run against the other model.
+        String[] inputs = S3;
+
+        // Skip instead of failing on machines that do not have the model checked out at this location.
+        Path modelPath = Paths.get(inputs[0]);
+        Assumptions.assumeTrue(Files.isRegularFile(modelPath), "Model file is not available: " + modelPath);
+
+        Model model = Model.assembler()
+                .addImport(modelPath)
+                .discoverModels()
+                .putProperty(ModelAssembler.ALLOW_UNKNOWN_TRAITS, true)
+                .assemble()
+                .unwrap();
+
+        ServiceShape service = model.expectShape(ShapeId.from(inputs[1]), ServiceShape.class);
+        EndpointRuleSet ruleSet = service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet();
+        HashConsGraph graph = new HashConsGraph(ruleSet);
+
+        double paths = graph.getPaths().size();
+        // Guards the average below against a division by zero and gives the test a real check.
+        Assertions.assertTrue(paths > 0, "Expected the rule set to produce at least one path");
+
+        double totalConditions = 0;
+        int maxDepth = 0;
+        for (HashConsGraph.BddPath path : graph.getPaths()) {
+            int conditionCount = path.getStatefulConditions().size() + path.getStatelessConditions().size();
+            totalConditions += conditionCount;
+            maxDepth = Math.max(maxDepth, conditionCount);
+            System.out.println(path);
+        }
+
+        System.out.println("Max depth: " + maxDepth);
+        System.out.println("Average path conditions: " + (totalConditions / paths));
+        System.out.println("BDD:");
+        System.out.println(graph.getBdd());
+
+        EndpointRuleSet updated = ruleSet.toBddForm();
+        EndpointRuleSetTrait updatedTrait = service
+                .expectTrait(EndpointRuleSetTrait.class)
+                .toBuilder()
+                .ruleSet(updated.toNode())
+                .build();
+        ServiceShape updatedService = service.toBuilder().addTrait(updatedTrait).build();
+        Model updatedModel = model.toBuilder().addShape(updatedService).build();
+
+        // Write to a fresh temporary file rather than a fixed /tmp path, using an explicit charset.
+        String json = Node.prettyPrintJson(ModelSerializer.builder().build().serialize(updatedModel));
+        Path output = Files.createTempFile("endpoint-bdd-", ".json");
+        Files.write(output, json.getBytes(StandardCharsets.UTF_8));
+        System.out.println("Wrote updated model to: " + output);
+    }
+}