From 6794a05e90f8c2186a06e0a5968de66d0b606ee7 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Fri, 24 Jan 2025 14:11:44 +0100 Subject: [PATCH 01/50] Add BO over graphs --- neps/optimizers/models/graphs/__init__.py | 0 .../models/graphs/context_managers.py | 64 ++++ .../examples/grakel_wl_usage_example.py | 50 +++ .../graph_aware_gp_optimization_example.py | 127 ++++++++ .../examples/single_task_gp_usage_example.py | 133 ++++++++ .../graph_aware_gp_optimization_example.py | 127 ++++++++ neps/optimizers/models/graphs/kernels.py | 290 ++++++++++++++++++ neps/optimizers/models/graphs/optimization.py | 96 ++++++ neps/optimizers/models/graphs/utils.py | 133 ++++++++ tests/test_graphs/__init__.py | 0 tests/test_graphs/test_botorch_wl_kernel.py | 102 ++++++ .../test_optimization_over_graphs.py | 214 +++++++++++++ tests/test_graphs/test_torch_wl_kernel.py | 243 +++++++++++++++ 13 files changed, 1579 insertions(+) create mode 100644 neps/optimizers/models/graphs/__init__.py create mode 100644 neps/optimizers/models/graphs/context_managers.py create mode 100644 neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py create mode 100644 neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py create mode 100644 neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py create mode 100644 neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py create mode 100644 neps/optimizers/models/graphs/kernels.py create mode 100644 neps/optimizers/models/graphs/optimization.py create mode 100644 neps/optimizers/models/graphs/utils.py create mode 100644 tests/test_graphs/__init__.py create mode 100644 tests/test_graphs/test_botorch_wl_kernel.py create mode 100644 tests/test_graphs/test_optimization_over_graphs.py create mode 100644 tests/test_graphs/test_torch_wl_kernel.py diff --git a/neps/optimizers/models/graphs/__init__.py b/neps/optimizers/models/graphs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/neps/optimizers/models/graphs/context_managers.py b/neps/optimizers/models/graphs/context_managers.py new file mode 100644 index 000000000..32d83abeb --- /dev/null +++ b/neps/optimizers/models/graphs/context_managers.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +from botorch.models import SingleTaskGP + +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel + +if TYPE_CHECKING: + import networkx as nx + from botorch.models.gp_regression_mixed import Kernel + + +@contextmanager +def set_graph_lookup( + kernel_or_gp: Kernel | SingleTaskGP, + new_graphs: list[nx.Graph], + *, + append: bool = True, +) -> Iterator[None]: + """Context manager to temporarily set the graph lookup for a kernel or GP model. + + Args: + kernel_or_gp (Kernel | SingleTaskGP): The kernel or GP model whose graph lookup is + to be set. + new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup. + append (bool, optional): Whether to append the new graphs to the existing graph + lookup. Defaults to True. 
+ """ + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + + # Determine the modules to update based on the input type + if isinstance(kernel_or_gp, SingleTaskGP): + modules = [k for k in kernel_or_gp.covar_module.sub_kernels() if + isinstance(k, BoTorchWLKernel)] + elif isinstance(kernel_or_gp, BoTorchWLKernel): + modules = [kernel_or_gp] + else: + assert hasattr(kernel_or_gp, + "sub_kernels"), "Kernel module must have sub_kernels method." + modules = [k for k in kernel_or_gp.sub_kernels() if + isinstance(k, BoTorchWLKernel)] + + # Save the current graph lookup and set the new graph lookup + for kern in modules: + if isinstance(kern, TorchWLKernel): + kern._get_node_neighbors.cache_clear() + kern._wl_iteration.cache_clear() + elif isinstance(kern, BoTorchWLKernel): + kern._compute_kernel.cache_clear() + + kernel_prev_graphs.append((kern, kern.graph_lookup)) + if append: + kern.set_graph_lookup([*kern.graph_lookup, *new_graphs]) + else: + kern.set_graph_lookup(new_graphs) + + yield + + # Restore the original graph lookup after the context manager exits + for kern, prev_graphs in kernel_prev_graphs: + kern.set_graph_lookup(prev_graphs) diff --git a/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py b/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py new file mode 100644 index 000000000..8b8c6a7ee --- /dev/null +++ b/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import networkx as nx +from grakel import WeisfeilerLehman, graph_from_networkx + + +def visualize_graph(G: nx.Graph): + """Visualize the NetworkX graph.""" + pos = nx.spring_layout(G) + nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue") + plt.show() + + +def add_labels(G: nx.Graph): + """Add labels to the nodes of the graph.""" + for node in G.nodes(): + G.nodes[node]["label"] = str(node) + + +# Create graphs +G1 = nx.Graph() +G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) +add_labels(G1) + +G2 = nx.Graph() +G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) +add_labels(G2) + +G3 = nx.Graph() +G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) +add_labels(G3) + +# Visualize the graphs +visualize_graph(G1) +visualize_graph(G2) +visualize_graph(G3) + +# Convert NetworkX graphs to Grakel format using graph_from_networkx +graph_list = list( + graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True) +) + +# Initialize the Weisfeiler-Lehman kernel +wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False) + +# Compute the kernel matrix +K = wl_kernel.fit_transform(graph_list) + +# Display the kernel matrix diff --git a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py new file mode 100644 index 000000000..1f9a8216b --- /dev/null +++ b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import time +from itertools import product +from typing import TYPE_CHECKING + +import networkx as nx +import torch +from botorch import fit_gpytorch_mll, settings +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel + +from 
neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel +from neps.optimizers.models.graphs.optimization import optimize_acqf_graph +from neps.optimizers.models.graphs.utils import min_max_scale, seed_all + +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +start_time = time.time() +settings.debug._set_state(True) +seed_all() + +TRAIN_CONFIGS = 50 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 1 +N_CATEGORICAL_VALUES_PER_CATEGORY = 2 +N_GRAPH = 1 + +assert N_GRAPH == 1, "This example only supports a single graph feature" + +# Generate random data +X = torch.cat([ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) +], dim=1) + +# Generate random graphs +graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + +# Generate random target values +y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 + +# Split into train and test sets +train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] +train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] +train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) + +train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) + +kernels = [ + ScaleKernel( + MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), + ScaleKernel(CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), + ScaleKernel(BoTorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(X.shape[1] - 1,))) +] + +# Create the Gaussian Process model +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) + +# Compute the posterior distribution +multivariate_normal: MultivariateNormal = gp.forward(train_x) + +# Making predictions on test data +with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs, append=False): + posterior = gp.forward(test_x) + predictions = posterior.mean + uncertainties = posterior.variance.sqrt() + covar = posterior.covariance_matrix + +# Fit the GP model +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +# Define the acquisition function +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define the bounds for optimization +bounds = torch.tensor([ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ + len(X) - 1] * N_GRAPH, +]) + +# Define fixed categorical features +cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in + range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} +fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in + product(*cats_per_column.values())] + +# Optimize the acquisition function with graph sampling +best_candidate, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=train_graphs, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, + q=1, +) + +# Print the results 
+ +# Clear caches after optimization to avoid memory leaks or unexpected behavior +BoTorchWLKernel._compute_kernel.cache_clear() +TorchWLKernel._get_node_neighbors.cache_clear() +TorchWLKernel._wl_iteration.cache_clear() diff --git a/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py b/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py new file mode 100644 index 000000000..b35cb3b51 --- /dev/null +++ b/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from itertools import product +from typing import TYPE_CHECKING + +import torch +from botorch import fit_gpytorch_mll +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from botorch.optim import optimize_acqf_mixed +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel + +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +TRAIN_CONFIGS = 10 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 2 +N_CATEGORICAL_VALUES_PER_CATEGORY = 3 + +kernels = [] + +# Create some random encoded hyperparameter configurations +X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +if N_NUMERICAL > 0: + X[:, :N_NUMERICAL] = torch.rand( + size=(TOTAL_CONFIGS, N_NUMERICAL), + dtype=torch.float64, + ) + +if N_CATEGORICAL > 0: + X[:, N_NUMERICAL:] = torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + size=(TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ) + +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + +if N_NUMERICAL > 0: + matern = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=N_NUMERICAL, + active_dims=tuple(range(N_NUMERICAL)), + ), + ) + kernels.append(matern) + +if N_CATEGORICAL > 0: + hamming = ScaleKernel( + CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), + ), + ) + kernels.append(hamming) + +combined_num_cat_kernel = AdditiveKernel(*kernels) + +train_x = X[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS] + +test_x = X[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:] + +K_matrix = combined_num_cat_kernel.forward(train_x, train_x) + +train_y = train_y.unsqueeze(-1) +test_y = test_y.unsqueeze(-1) + +gp = SingleTaskGP( + train_X=train_x, + train_Y=train_y, + covar_module=combined_num_cat_kernel, +) + +multivariate_normal: MultivariateNormal = gp.forward(train_x) + +# =============== Fitting the GP using botorch =============== + + +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define bounds +bounds = torch.tensor( + [ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + ] +) + +# Setup categorical feature optimization +cats_per_column: dict[int, list[float]] = { + column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] + for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) +} + +# Generate fixed categorical features +fixed_cats: list[dict[int, float]] +if len(cats_per_column) == 1: + col, choice_indices = 
next(iter(cats_per_column.items())) + fixed_cats = [{col: i} for i in choice_indices] +else: + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo, strict=False)) + for combo in product(*cats_per_column.values()) + ] + +best_candidate, best_score = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + num_restarts=10, + raw_samples=10, + q=1, +) + diff --git a/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py new file mode 100644 index 000000000..1f9a8216b --- /dev/null +++ b/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import time +from itertools import product +from typing import TYPE_CHECKING + +import networkx as nx +import torch +from botorch import fit_gpytorch_mll, settings +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel + +from neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel +from neps.optimizers.models.graphs.optimization import optimize_acqf_graph +from neps.optimizers.models.graphs.utils import min_max_scale, seed_all + +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +start_time = time.time() +settings.debug._set_state(True) +seed_all() + +TRAIN_CONFIGS = 50 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 1 +N_CATEGORICAL_VALUES_PER_CATEGORY = 2 +N_GRAPH = 1 + +assert N_GRAPH == 1, "This example only supports a single graph feature" + +# Generate random data +X = torch.cat([ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) +], dim=1) + +# Generate random graphs +graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + +# Generate random target values +y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 + +# Split into train and test sets +train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] +train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] +train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) + +train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) + +kernels = [ + ScaleKernel( + MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), + ScaleKernel(CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), + ScaleKernel(BoTorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(X.shape[1] - 1,))) +] + +# Create the Gaussian Process model +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) + +# Compute the posterior distribution +multivariate_normal: MultivariateNormal = gp.forward(train_x) + +# Making predictions on test data +with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs, append=False): + posterior = gp.forward(test_x) + predictions = posterior.mean + 
uncertainties = posterior.variance.sqrt() + covar = posterior.covariance_matrix + +# Fit the GP model +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +# Define the acquisition function +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define the bounds for optimization +bounds = torch.tensor([ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ + len(X) - 1] * N_GRAPH, +]) + +# Define fixed categorical features +cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in + range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} +fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in + product(*cats_per_column.values())] + +# Optimize the acquisition function with graph sampling +best_candidate, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=train_graphs, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, + q=1, +) + +# Print the results + +# Clear caches after optimization to avoid memory leaks or unexpected behavior +BoTorchWLKernel._compute_kernel.cache_clear() +TorchWLKernel._get_node_neighbors.cache_clear() +TorchWLKernel._wl_iteration.cache_clear() diff --git a/neps/optimizers/models/graphs/kernels.py b/neps/optimizers/models/graphs/kernels.py new file mode 100644 index 000000000..5e9a38375 --- /dev/null +++ b/neps/optimizers/models/graphs/kernels.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +import torch +from botorch.models.gp_regression_mixed import Kernel +from torch import Tensor +from torch.nn import Module + +from neps.optimizers.models.graphs.utils import graphs_to_tensors + +if TYPE_CHECKING: + import networkx as nx + + +class BoTorchWLKernel(Kernel): + """A custom kernel for Gaussian Processes using the Weisfeiler-Lehman (WL) algorithm. + + This kernel computes similarities between graphs based on their structural properties + using the WL algorithm. It is designed to be used with BoTorch and GPyTorch for + Gaussian Process regression. + + Args: + graph_lookup (list[nx.Graph]): List of NetworkX graphs. + n_iter (int, optional): Number of WL iterations to perform. Default is 5. + normalize (bool, optional): Whether to normalize the kernel matrix. + Default is True. + active_dims (tuple[int, ...]): Dimensions of the input to consider. + Not used in this kernel but included for compatibility with the base Kernel class. + **kwargs (Any): Additional arguments for the base Kernel class. + + Attributes: + graph_lookup (list[nx.Graph]): List of graphs used for kernel computation. + n_iter (int): Number of WL iterations. + normalize (bool): Whether to normalize the kernel matrix. + adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs. + label_cache (list[Tensor]): Cached initial node labels of the graphs. 
+ """ + has_lengthscale = False + + def __init__( + self, + graph_lookup: list[nx.Graph], + n_iter: int = 5, + *, + normalize: bool = True, + active_dims: tuple[int, ...], + **kwargs: Any, + ) -> None: + super().__init__(active_dims=active_dims, **kwargs) + self.graph_lookup = graph_lookup + self.n_iter = n_iter + self.normalize = normalize + self._precompute_graph_data() + + def _precompute_graph_data(self) -> None: + """Precompute and cache adjacency matrices and initial node labels.""" + self.adjacency_cache, self.label_cache = graphs_to_tensors( + self.graph_lookup, device=self.device + ) + + def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + """Update the graph lookup and refresh the cached data.""" + self.graph_lookup = graph_lookup + self._precompute_graph_data() + + def forward( + self, + x1: Tensor, + x2: Tensor, + *, + diag: bool = False, + last_dim_is_batch: bool = False, + **params: Any, + ) -> Tensor: + """Compute kernel matrix containing pairwise similarities between graphs.""" + if last_dim_is_batch: + raise NotImplementedError("Batch dimension handling is not implemented.") + + if x1.ndim == 3: + return self._handle_batched_input(x1, x2, diag) + + indices1, indices2 = self._prepare_indices(x1, x2) + + return self._compute_kernel(tuple(indices1), tuple(indices2), diag) + + def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor: + """Handle computation for batched input tensors.""" + q_dim_size = x1.shape[0] + assert x2.shape[0] == q_dim_size + + out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device) + for q in range(q_dim_size): + out[q] = self.forward(x1[q], x2[q], diag=diag) + return out + + def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]: + """Convert tensor indices to integer lists.""" + indices1 = x1.flatten().to(torch.int64).tolist() + indices2 = x2.flatten().to(torch.int64).tolist() + + # Check for missing graph indices (-1) and handle them + # Explanation: The index `-1` is used as a placeholder for "missing" or "invalid" + # graphs. This can occur when a graph feature is missing or undefined, such as + # during the exploration of new candidates where no corresponding graph is + # available in the `graph_lookup`. The kernel expects non-negative indices, so we + # need to convert `-1` to the index of the last graph in the lookup. + + # Use the last graph in the lookup as a placeholder + last_graph_idx = len(self.graph_lookup) - 1 + + if -1 in indices1: + # Replace any `-1` indices with the index of the last graph. + indices1 = [last_graph_idx if i == -1 else i for i in indices1] + + if -1 in indices2: + # Replace any `-1` indices with the index of the last graph. + indices2 = [last_graph_idx if i == -1 else i for i in indices2] + + return indices1, indices2 + + @lru_cache(maxsize=128) + def _compute_kernel( + self, + indices1: tuple[int, ...], + indices2: tuple[int, ...], + diag: bool, + ) -> Tensor: + """Compute the kernel matrix. + + Args: + indices1: Tuple of indices for the first set of graphs. + indices2: Tuple of indices for the second set of graphs. + diag: Whether to return only the diagonal of the kernel matrix. + + Returns: + A Tensor representing the kernel matrix. 
+ """ + all_graphs = list(set(indices1).union(indices2)) + adj_matrices = [self.adjacency_cache[i] for i in all_graphs] + label_tensors = [self.label_cache[i] for i in all_graphs] + + # Compute full kernel matrix + K_full = self._compute_base_kernel(adj_matrices, label_tensors) + + # Map indices to their positions in all_graphs + idx1 = [all_graphs.index(i) for i in indices1] + idx2 = [all_graphs.index(i) for i in indices2] + + # Extract the relevant submatrix + K = K_full[idx1][:, idx2] + + # Return the diagonal if requested + if diag: + return torch.diag(K) + + return K + + def _compute_base_kernel( + self, adj_matrices: list[Tensor], label_tensors: list[Tensor] + ) -> Tensor: + """Compute the base kernel matrix using WL algorithm.""" + _kernel = TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize) + return _kernel(adj_matrices, label_tensors) + + +class TorchWLKernel(Module): + """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. + + The WL Kernel is a graph kernel that measures similarity between graphs based on + their structural properties. It works by iteratively updating node labels based on + their neighborhoods and computing feature vectors from label distributions. + + Args: + n_iter: Number of WL iterations to perform + normalize: bool, optional. Whether to normalize the kernel matrix + + Attributes: + device: torch.device for computation (CPU/GPU) + label_dict: Mapping from node labels to numerical indices + label_counter: Counter for generating new label indices + """ + + def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: + super().__init__() + self.n_iter = n_iter + self.normalize = normalize + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Keep track of labels across iterations + self.label_dict: dict[str, int] = {} + self.label_counter: int = 0 + + @lru_cache(maxsize=128) + def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]: + """Extract neighborhood information from adjacency matrix.""" + if adj.layout == torch.sparse_csr: + adj = adj.to_sparse_coo() + + adj = adj.coalesce() + rows, cols = adj.indices() + num_nodes = adj.size(0) + + neighbors: list[list[int]] = [[] for _ in range(num_nodes)] + for row, col in zip(rows.tolist(), cols.tolist(), strict=False): + neighbors[row].append(col) + + return neighbors + + @lru_cache(maxsize=128) + def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: + """Perform one WL iteration.""" + if not self.label_dict: + # Start new labels after initial ones + self.label_counter = int(labels.max().item()) + 1 + + num_nodes = labels.size(0) + new_labels: list[int] = [] + neighbors = self._get_node_neighbors(adj) + + for node_idx in range(num_nodes): + # Get current node label + node_label = int(labels[node_idx].item()) + neighbor_labels = sorted([int(labels[n].item()) for n in neighbors[node_idx]]) + + credential = f"{node_label},{neighbor_labels}" + + # Update label dictionary + new_labels.append( + self.label_dict.setdefault(credential, len(self.label_dict)) + ) + + return torch.tensor(new_labels, dtype=torch.int64, device=self.device) + + def _compute_feature_vector(self, all_labels: list[list[Tensor]]) -> Tensor: + """Compute the histogram feature vector for all graphs.""" + batch_size = len(all_labels[0]) + features: list[Tensor] = [] + + for iteration_labels in all_labels: + # Find maximum label value across all graphs in this iteration + max_label = int(max(label.max().item() for label in iteration_labels)) + 1 + + iter_features = 
torch.zeros((batch_size, max_label), device=self.device) + + # Compute label frequencies + for graph_idx, labels in enumerate(iteration_labels): + counts = torch.bincount(labels, minlength=max_label) + iter_features[graph_idx] = counts + + features.append(iter_features) + + return torch.cat(features, dim=1) + + def forward(self, adj_matrices: list[Tensor], label_tensors: list[Tensor]) -> Tensor: + """Compute WL kernel matrix for a list of graphs. + + Args: + adj_matrices: Precomputed sparse adjacency matrices for graphs. + label_tensors: Precomputed node label tensors for graphs. + + Returns: + Kernel matrix containing pairwise graph similarities. + """ + if len(adj_matrices) != len(label_tensors): + raise ValueError("Mismatch between adjacency matrices and label tensors.") + + # Reset label dictionary for new computation + self.label_dict = {} + # Store all label iterations + all_labels: list[list[Tensor]] = [label_tensors] + + # Perform WL iterations + for _ in range(self.n_iter): + new_labels = [ + self._wl_iteration(adj, labels) + for adj, labels in zip(adj_matrices, all_labels[-1], strict=False) + ] + all_labels.append(new_labels) + + # Compute feature vectors and kernel matrix (similarity matrix) + final_features = self._compute_feature_vector(all_labels) + kernel_matrix = torch.mm(final_features, final_features.t()) + + if self.normalize: + diag = torch.sqrt(torch.diag(kernel_matrix)) + kernel_matrix /= torch.outer(diag, diag) + + return kernel_matrix diff --git a/neps/optimizers/models/graphs/optimization.py b/neps/optimizers/models/graphs/optimization.py new file mode 100644 index 000000000..07542c3ae --- /dev/null +++ b/neps/optimizers/models/graphs/optimization.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from botorch.optim import optimize_acqf_mixed + +from neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.utils import sample_graphs + +if TYPE_CHECKING: + import networkx as nx + from botorch.acquisition import AcquisitionFunction + + +def optimize_acqf_graph( + acq_function: AcquisitionFunction, + bounds: torch.Tensor, + fixed_features_list: list[dict[int, int]] | None = None, + num_graph_samples: int = 10, + train_graphs: list[nx.Graph] | None = None, + num_restarts: int = 10, + raw_samples: int = 1024, + q: int = 1, +) -> tuple[torch.Tensor, float]: + """Optimize an acquisition function with graph sampling. + + This function optimizes the acquisition function by sampling graphs from the training + set, temporarily updating the kernel's graph lookup, and evaluating the acquisition + function for each sampled graph. The best candidate and its corresponding acquisition + score are returned. + + Args: + acq_function (AcquisitionFunction): The acquisition function to optimize. + bounds (torch.Tensor): A 2 x d tensor of bounds for numerical and categorical + features, where d is the number of features. + fixed_features_list (list[dict[int, float]] | None): A list of dictionaries + specifying fixed categorical feature configurations. Each dictionary maps + feature indices to their fixed values. Defaults to None. + num_graph_samples (int): The number of graphs to sample from the training set. + Defaults to 10. + train_graphs (list[nx.Graph] | None): The original training graphs. If None, a + ValueError is raised. + num_restarts (int): The number of optimization restarts. Defaults to 10. + raw_samples (int): The number of raw samples to generate for optimization. 
+ Defaults to 1024. + q (int): The number of candidates to generate. Defaults to 1. + + Returns: + tuple[torch.Tensor, float]: A tuple containing the best candidate (as a tensor) + and its corresponding acquisition score. + + Raises: + ValueError: If `train_graphs` is None. + """ + if train_graphs is None: + raise ValueError("train_graphs cannot be None.") + + # Sample graphs from the training set + sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + + # Initialize lists to store the best candidates and their scores + best_candidates, best_scores = [], [] + + # Get the index of the graph feature in the bounds + graph_idx = bounds.shape[1] - 1 + + # Iterate through each sampled graph + for graph in sampled_graphs: + # Temporarily set the graph lookup for the kernel + with set_graph_lookup(acq_function.model.covar_module, [graph], append=True): + # Iterate through each fixed feature configuration (if provided) + for fixed_features in fixed_features_list or [{}]: + # Add the graph index to the fixed features, indicating that the last + # graphin the lookup should be used + updated_fixed_features = {**fixed_features, graph_idx: -1.0} + + # Optimize the acquisition function with the updated fixed features + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[updated_fixed_features], + num_restarts=num_restarts, + raw_samples=raw_samples, + q=q, + ) + + # Store the candidates and their scores + best_candidates.append(candidates) + best_scores.append(scores) + + # Find the index of the best score + best_idx = torch.argmax(torch.tensor(best_scores)) + + # Return the best candidate and its score + return best_candidates[best_idx], best_scores[best_idx].item() diff --git a/neps/optimizers/models/graphs/utils.py b/neps/optimizers/models/graphs/utils.py new file mode 100644 index 000000000..22e9d8a64 --- /dev/null +++ b/neps/optimizers/models/graphs/utils.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import random + +import networkx as nx +import numpy as np +import torch + + +def seed_all(seed: int = 100) -> None: + """Seed all random generators for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Ensure reproducibility with CuDNN (may reduce performance) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def min_max_scale(tensor: torch.Tensor) -> torch.Tensor: + """Scale the input tensor to the range [0, 1].""" + min_vals = tensor.min(dim=0, keepdim=True).values + max_vals = tensor.max(dim=0, keepdim=True).values + return (tensor - min_vals) / (max_vals - min_vals) + + +def graphs_to_tensors( + graphs: list[nx.Graph], + device: torch.device | None = None +) -> tuple[list[torch.sparse.Tensor], list[torch.Tensor]]: + """Convert a list of NetworkX graphs into sparse adjacency matrices and label tensors. + + Args: + graphs (List[nx.Graph]): A list of NetworkX graphs. + device (torch.device | None): The device to place the tensors on. + Default is CPU. + + Returns: + Tuple[List[torch.sparse.Tensor], List[torch.Tensor]]: + A tuple containing: + - A list of sparse adjacency matrices. + - A list of label tensors. 
+ """ + if device is None: + device = torch.device("cpu") + + adjacency_matrices = [] + label_tensors = [] + + # Create a consistent label mapping across all graphs + label_dict: dict[str, int] = {} + label_counter: int = 0 + + for graph in graphs: + # Create adjacency matrix + edges = list(graph.edges()) + num_nodes = graph.number_of_nodes() + + if not edges: + adj = torch.sparse_coo_tensor( + indices=torch.empty((2, 0), dtype=torch.long), + values=torch.empty(0), + size=(num_nodes, num_nodes), + device=device, + ).to_sparse_csr() + else: + edge_indices = edges + [(v, u) for u, v in edges] + rows, cols = zip(*edge_indices, strict=False) + indices = torch.tensor([rows, cols], dtype=torch.long, device=device) + values = torch.ones(len(edge_indices), dtype=torch.float, device=device) + adj = torch.sparse_coo_tensor( + indices, values, (num_nodes, num_nodes), device=device + ).to_sparse_csr() + + adjacency_matrices.append(adj) + + # Create label tensor + node_labels: list[int] = [] + for node in range(graph.number_of_nodes()): + if "label" in graph.nodes[node]: + label = graph.nodes[node]["label"] + if label not in label_dict: + label_dict[label] = label_counter + label_counter += 1 + node_labels.append(label_dict[label]) + else: + node_labels.append(node) + + label_tensors.append(torch.tensor(node_labels, dtype=torch.long, device=device)) + + return adjacency_matrices, label_tensors + + +def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: + """Sample graphs using random walks or edge modifications. + + Args: + graphs (list[nx.Graph]): Existing training graphs. + num_samples (int): Number of graph samples to generate. + + Returns: + list[nx.Graph]: Sampled graphs. + """ + sampled_graphs = [] + for _ in range(num_samples): + base_graph = random.choice(graphs) + sampled_graph = base_graph.copy() + + # More aggressive modifications + num_modifications = random.randint(2, 5) # Increase minimum modifications + for _ in range(num_modifications): + if random.random() > 0.3: # 70% chance to add edge + nodes = list(sampled_graph.nodes) + if len(nodes) >= 2: + u, v = random.sample(nodes, 2) + if not sampled_graph.has_edge(u, v): + sampled_graph.add_edge(u, v) + elif sampled_graph.edges: # 30% chance to remove edge + u, v = random.choice(list(sampled_graph.edges)) + sampled_graph.remove_edge(u, v) + + # Ensure the graph stays connected + if not nx.is_connected(sampled_graph): + components = list(nx.connected_components(sampled_graph)) + for i in range(len(components) - 1): + u = random.choice(list(components[i])) + v = random.choice(list(components[i + 1])) + sampled_graph.add_edge(u, v) + + sampled_graphs.append(sampled_graph) + + return sampled_graphs diff --git a/tests/test_graphs/__init__.py b/tests/test_graphs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_graphs/test_botorch_wl_kernel.py b/tests/test_graphs/test_botorch_wl_kernel.py new file mode 100644 index 000000000..2ded6237e --- /dev/null +++ b/tests/test_graphs/test_botorch_wl_kernel.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import networkx as nx +import pytest +import torch +from botorch.models.gp_regression_mixed import Kernel + +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel + + +def create_simple_graphs(num_graphs: int) -> list[nx.Graph]: + """Helper function to create a list of graphs.""" + graphs = [] + for _i in range(num_graphs): + G = nx.Graph() + G.add_nodes_from([0, 1, 2]) + G.add_edges_from([(0, 1), (1, 2)]) + graphs.append(G) + 
return graphs + + +class TestBoTorchWLKernel: + @pytest.fixture + def simple_graphs(self) -> list[nx.Graph]: + return create_simple_graphs(3) + + @pytest.fixture + def wl_kernel(self, simple_graphs: list[nx.Graph]) -> BoTorchWLKernel: + return BoTorchWLKernel( + graph_lookup=simple_graphs, + n_iter=2, + normalize=True, + active_dims=(0,), + ) + + def test_initialization( + self, wl_kernel: BoTorchWLKernel, simple_graphs: list[nx.Graph] + ) -> None: + """Test that the kernel is initialized correctly.""" + assert isinstance(wl_kernel, Kernel) + assert len(wl_kernel.graph_lookup) == len(simple_graphs) + assert wl_kernel.n_iter == 2 + assert wl_kernel.normalize is True + assert torch.equal(wl_kernel.active_dims, torch.tensor([0])) + + def test_precompute_graph_data(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that graph data is precomputed correctly.""" + assert hasattr(wl_kernel, "adjacency_cache") + assert hasattr(wl_kernel, "label_cache") + assert len(wl_kernel.adjacency_cache) == len(wl_kernel.graph_lookup) + assert len(wl_kernel.label_cache) == len(wl_kernel.graph_lookup) + + def test_set_graph_lookup(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that the graph lookup can be updated.""" + new_graphs = create_simple_graphs(2) + wl_kernel.set_graph_lookup(new_graphs) + assert len(wl_kernel.graph_lookup) == 2 + assert len(wl_kernel.adjacency_cache) == 2 + assert len(wl_kernel.label_cache) == 2 + + def test_forward_self_kernel(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for self-similarity.""" + x = torch.tensor([[0], [1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x, x) + assert K.shape == (3, 3) # Kernel matrix should be 3x3 + assert torch.allclose(K, K.T) # Kernel matrix should be symmetric + + def test_forward_cross_kernel(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for cross-similarity.""" + x1 = torch.tensor([[0], [1]], dtype=torch.float64) + x2 = torch.tensor([[1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x1, x2) + assert K.shape == (2, 2) # Kernel matrix should be 2x2 + + def test_forward_diagonal(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for diagonal only.""" + x = torch.tensor([[0], [1], [2]], dtype=torch.float64) + K = wl_kernel.forward(x, x, diag=True) + assert K.shape == (3,) # Diagonal should be a vector of length 3 + + def test_handle_negative_one_index(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the handling of the -1 index.""" + x = torch.tensor([[-1], [0], [1]], dtype=torch.float64) + K = wl_kernel.forward(x, x) + assert K.shape == (3, 3) # Kernel matrix should be 3x3 + # Ensure that -1 refers to the last graph + last_graph_idx = len(wl_kernel.graph_lookup) - 1 + assert torch.allclose(K[0, 0], K[last_graph_idx, last_graph_idx]) + + def test_forward_batched_input(self, wl_kernel: BoTorchWLKernel) -> None: + """Test the kernel computation for batched input.""" + x1 = torch.tensor([[[0], [1]], [[1], [2]]], dtype=torch.float64) + x2 = torch.tensor([[[1], [2]], [[0], [1]]], dtype=torch.float64) + K = wl_kernel.forward(x1, x2) + assert K.shape == (2, 2, 2) # Batched kernel matrix should be 2x2x2 + + def test_forward_invalid_input(self, wl_kernel: BoTorchWLKernel) -> None: + """Test that invalid input raises an error.""" + x1 = torch.tensor([[0], [1], [2]], dtype=torch.float64) + x2 = torch.tensor([[0], [1]], dtype=torch.float64) + with pytest.raises(NotImplementedError): + wl_kernel.forward(x1, x2, last_dim_is_batch=True) diff --git 
a/tests/test_graphs/test_optimization_over_graphs.py b/tests/test_graphs/test_optimization_over_graphs.py new file mode 100644 index 000000000..f4d89778f --- /dev/null +++ b/tests/test_graphs/test_optimization_over_graphs.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from itertools import product + +import networkx as nx +import pytest +import torch +from botorch import fit_gpytorch_mll +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.kernels import CategoricalKernel +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel + +from neps.optimizers.models.graphs.context_managers import set_graph_lookup +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel +from neps.optimizers.models.graphs.optimization import optimize_acqf_graph, sample_graphs +from neps.optimizers.models.graphs.utils import min_max_scale + + +class TestGraphOptimizationPipeline: + @pytest.fixture + def setup_data(self) -> dict: + """Fixture to set up common data for tests.""" + TRAIN_CONFIGS = 50 + TEST_CONFIGS = 10 + TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + + N_NUMERICAL = 2 + N_CATEGORICAL = 1 + N_CATEGORICAL_VALUES_PER_CATEGORY = 2 + N_GRAPH = 1 + + # Generate random data + X = torch.cat([ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, + (TOTAL_CONFIGS, N_CATEGORICAL), dtype=torch.float64), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) + ], dim=1) + + # Generate random graphs + graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + + # Generate random target values + y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 + + # Split into train and test sets + train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] + train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] + train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) + + # Scale the data + train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) + + return { + "train_x": train_x, + "test_x": test_x, + "train_graphs": train_graphs, + "test_graphs": test_graphs, + "train_y": train_y, + "test_y": test_y, + "N_NUMERICAL": N_NUMERICAL, + "N_CATEGORICAL": N_CATEGORICAL, + "N_CATEGORICAL_VALUES_PER_CATEGORY": N_CATEGORICAL_VALUES_PER_CATEGORY, + "N_GRAPH": N_GRAPH, + } + + def test_gp_fit_and_predict(self, setup_data: dict) -> None: + """Test fitting the GP and making predictions.""" + train_x = setup_data["train_x"] + train_y = setup_data["train_y"] + test_x = setup_data["test_x"] + train_graphs = setup_data["train_graphs"] + setup_data["test_graphs"] + + # Define the kernels + kernels = [ + ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]))), + ScaleKernel( + CategoricalKernel(ard_num_dims=setup_data["N_CATEGORICAL"], + active_dims=range(setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + + setup_data["N_CATEGORICAL"]) + ) + ), + ScaleKernel( + BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(train_x.shape[1] - 1,))) + ] + + # Create the GP model + gp = SingleTaskGP(train_X=train_x, train_Y=train_y, + covar_module=AdditiveKernel(*kernels)) + + # Fit the GP + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) + + # Make predictions on the test set + with torch.no_grad(): + posterior = 
gp.forward(test_x) + predictions = posterior.mean + uncertainties = posterior.variance.sqrt() + + # Ensure predictions are in the correct shape (10, 1) + predictions = predictions.unsqueeze(-1) # Reshape to (10, 1) + + # Basic checks + assert predictions.shape == (setup_data["test_x"].shape[0], 1) + assert uncertainties.shape == (setup_data["test_x"].shape[0],) + + def test_acquisition_function_optimization(self, setup_data: dict) -> None: + """Test optimizing the acquisition function with graph sampling.""" + train_x = setup_data["train_x"] + train_y = setup_data["train_y"] + train_graphs = setup_data["train_graphs"] + + # Define the kernels + kernels = [ + ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]))), + ScaleKernel( + CategoricalKernel( + ard_num_dims=setup_data["N_CATEGORICAL"], + active_dims=range(setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + + setup_data["N_CATEGORICAL"]) + ) + ), + ScaleKernel( + BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(train_x.shape[1] - 1,))) + ] + + # Create the GP model + gp = SingleTaskGP(train_X=train_x, train_Y=train_y, + covar_module=AdditiveKernel(*kernels)) + + # Fit the GP + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) + + # Define the acquisition function + acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, + ) + + # Define bounds for optimization + bounds = torch.tensor([ + [0.0] * setup_data["N_NUMERICAL"] + [0.0] * setup_data["N_CATEGORICAL"] + [ + -1.0] * setup_data["N_GRAPH"], + [1.0] * setup_data["N_NUMERICAL"] + [ + float(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"] - 1)] * setup_data[ + "N_CATEGORICAL"] + [len(train_x) - 1] * setup_data["N_GRAPH"], + ]) + + # Define fixed categorical features + cats_per_column = {i: list(range(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"])) + for i in range(setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data[ + "N_CATEGORICAL"])} + fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in + product(*cats_per_column.values())] + + # Optimize the acquisition function + best_candidate, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=train_graphs, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, + q=1, + ) + + # Basic checks + assert best_candidate.shape == (1, train_x.shape[1]) + assert isinstance(best_score, float) + + def test_graph_sampling(self, setup_data: dict) -> None: + """Test the graph sampling functionality.""" + train_graphs = setup_data["train_graphs"] + num_samples = 5 + + # Sample graphs + sampled_graphs = sample_graphs(train_graphs, num_samples=num_samples) + + # Basic checks + assert len(sampled_graphs) == num_samples + for graph in sampled_graphs: + assert isinstance(graph, nx.Graph) + assert nx.is_connected(graph) + + def test_set_graph_lookup(self, setup_data: dict) -> None: + """Test the set_graph_lookup context manager.""" + train_graphs = setup_data["train_graphs"] + test_graphs = setup_data["test_graphs"] + + # Define the kernel + kernel = BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(0,)) + + # Use the context manager to temporarily set the graph lookup + with set_graph_lookup(kernel, test_graphs, append=True): + assert len(kernel.graph_lookup) == 
len(train_graphs) + len(test_graphs) + + # Check that the original graph lookup is restored + assert len(kernel.graph_lookup) == len(train_graphs) diff --git a/tests/test_graphs/test_torch_wl_kernel.py b/tests/test_graphs/test_torch_wl_kernel.py new file mode 100644 index 000000000..a30de8a30 --- /dev/null +++ b/tests/test_graphs/test_torch_wl_kernel.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import networkx as nx +import numpy as np +import pytest +import torch +from grakel import WeisfeilerLehman, graph_from_networkx + +from neps.optimizers.models.graphs.kernels import TorchWLKernel +from neps.optimizers.models.graphs.utils import graphs_to_tensors + + +class TestTorchWLKernel: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @pytest.fixture + def example_graphs_set(self) -> list[nx.Graph]: + # Create example graphs for testing + G1 = nx.Graph() + G1.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (3, 4)]) + for node in G1.nodes(): + G1.nodes[node]["label"] = str(node) + + G2 = nx.Graph() + G2.add_edges_from([(0, 1), (1, 2), (3, 4), (4, 0)]) + for node in G2.nodes(): + G2.nodes[node]["label"] = str(node) + + G3 = nx.Graph() + G3.add_edges_from([(0, 1), (1, 3), (3, 2), (2, 4), (4, 0), (1, 2)]) + for node in G3.nodes(): + G3.nodes[node]["label"] = str(node) + + return [G1, G2, G3] + + @pytest.fixture + def random_graphs_sets(self) -> list[list[nx.Graph]]: + # Set a seed for reproducibility + seed = 100 + np.random.seed(seed) + torch.manual_seed(seed) + random_graph_sets = [] + + # Generate 10 random sets of graphs + for _ in range(10): + # Number of graphs in the set (2 to 10) + num_graphs = np.random.randint(2, 11) + graph_set = [] + + for _ in range(num_graphs): + # Number of nodes in the graph (3 to 50) + num_nodes = np.random.randint(3, 51) + G = nx.Graph() + + # Add nodes with labels + for node in range(num_nodes): + G.add_node(node, label=str(node)) + + # Add random edges + for u in range(num_nodes): + for v in range(u + 1, num_nodes): + if np.random.rand() > 0.5: # 50% chance to add an edge + G.add_edge(u, v) + + graph_set.append(G) + + random_graph_sets.append(graph_set) + + return random_graph_sets + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 5, 10]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_against_grakel( + self, n_iter: int, normalize: bool, random_graphs_sets: list[list[nx.Graph]] + ) -> None: + for graph_set in random_graphs_sets: + adjacency_matrices, label_tensors = graphs_to_tensors( + graph_set, device=self.device) + + # Initialize Torch WL Kernel + torch_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = torch_kernel(adjacency_matrices, + label_tensors).cpu().numpy() + + # Initialize GraKel WL Kernel + grakel_graphs = list( + graph_from_networkx(graph_set, node_labels_tag="label", as_Graph=True)) + grakel_kernel = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_kernel.fit_transform(grakel_graphs) + + # Compare the kernel matrices + np.testing.assert_allclose( + torch_kernel_matrix, + grakel_kernel_matrix, + rtol=1e-5, + atol=1e-8, + err_msg=f"Kernel matrices differ for graph={graph_set}, n_iter={n_iter}" + ) + + def test_empty_graph(self) -> None: + G_empty = nx.Graph() + G_empty.add_node(0) + G_empty.nodes[0]["label"] = "0" + + adjacency_matrices, label_tensors = graphs_to_tensors([G_empty], + device=self.device) + + # Initialize kernel and compute + kernel = TorchWLKernel(n_iter=3, normalize=True) + kernel_matrix = 
kernel(adjacency_matrices, label_tensors) + + # For a single graph, should get a 1x1 matrix with value 1.0 + expected = torch.ones(1, 1, device=self.device) + torch.testing.assert_close(kernel_matrix, expected) + + def test_invalid_input(self) -> None: + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + with pytest.raises(ValueError, + match="Mismatch between adjacency matrices and label tensors"): + wl_kernel([], [torch.tensor([0])]) + + def test_kernel_on_single_node_graph(self) -> None: + G_single = nx.Graph() + G_single.add_node(0) + G_single.nodes[0]["label"] = "0" + + adjacency_matrices, label_tensors = graphs_to_tensors([G_single], + device=self.device) + + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(adjacency_matrices, label_tensors) + + expected = torch.ones(1, 1, device=self.device) + torch.testing.assert_close(K, expected) + + def test_wl_kernel_with_empty_graph_and_reordered_edges( + self, random_graphs_sets: list[list[nx.Graph]] + ) -> None: + """Test the TorchWLKernel with an empty graph and a graph with reordered edges.""" + for graph_set in random_graphs_sets: + # Create an empty graph + G_empty = nx.Graph() + G_empty.add_node(0) + G_empty.nodes[0]["label"] = "0" + + # Select the first graph from the set to reorder its edges + G = graph_set[0] + G_reordered = nx.Graph() + + # Add all nodes from the original graph to G_reordered + for node in G.nodes(): + G_reordered.add_node(node, label=G.nodes[node]["label"]) + + # Reorder edges randomly + edges = list(G.edges()) + np.random.shuffle(edges) # Randomly shuffle the edges + G_reordered.add_edges_from(edges) + + # Combine the empty graph, original graph, and reordered graph + graphs = [G_empty, G, G_reordered] + adjacency_matrices, label_tensors = graphs_to_tensors( + graphs, device=self.device + ) + + # Initialize and compute the kernel + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(adjacency_matrices, label_tensors) + + assert K.shape == (3, 3), "Kernel matrix shape is incorrect" + assert torch.allclose(K[1, 1], K[2, 2]), \ + "Kernel value for original and reordered graphs should be the same" + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_with_different_node_labels( + self, n_iter: int, normalize: bool, example_graphs_set: list[nx.Graph] + ) -> None: + graphs = [] + for i, G in enumerate(example_graphs_set): + G_copy = G.copy() + prefix = ["node_", "vertex_", "n"][i] + for node in G_copy.nodes(): + G_copy.nodes[node]["label"] = f"{prefix}{node}" + graphs.append(G_copy) + + adjacency_matrices, label_tensors = graphs_to_tensors(graphs, + device=self.device) + + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(adjacency_matrices, label_tensors).cpu().numpy() + + grakel_graphs = graph_from_networkx(graphs, node_labels_tag="label") + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + np.testing.assert_allclose( + torch_kernel_matrix, + grakel_kernel_matrix, + rtol=1e-5, + atol=1e-8, + err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}" + ) + + def test_wl_kernel_with_same_node_labels( + self, example_graphs_set: list[nx.Graph] + ) -> None: + """Test WL kernel behavior with same node labels but different structures. + + Even when all nodes have the same label, the WL kernel should: + 1. Produce a symmetric matrix + 2. 
Have 1.0 on the diagonal (self-similarity) + 3. Have off-diagonal values less than 1.0 (different structures) + 4. Maintain non-negative values (it's a valid kernel) + """ + graphs = [] + for G in example_graphs_set: + G_copy = G.copy() + for node in G_copy.nodes(): + G_copy.nodes[node]["label"] = "A" + graphs.append(G_copy) + + adjacency_matrices, label_tensors = graphs_to_tensors( + graphs, device=self.device) + + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(adjacency_matrices, label_tensors) + + # Check basic properties + assert K.shape == (3, 3), "Kernel matrix shape is incorrect" + assert torch.allclose(K, K.T, atol=1e-4), "Kernel matrix is not symmetric" + + # Check diagonal elements are 1 (normalized self-similarity) + assert torch.allclose(torch.diag(K), torch.ones_like(torch.diag(K)), atol=1e-4), \ + "Diagonal elements should be 1.0" + + # Check off-diagonal elements are less than 1 (different structures) + off_diag_mask = ~torch.eye(K.shape[0], dtype=torch.bool, device=self.device) + assert torch.all(K[off_diag_mask] < 1.0), \ + "Off-diagonal elements should be less than 1.0 for different structures" + + # Check all elements are non-negative (valid kernel) + assert torch.all(K >= 0), "Kernel values should be non-negative" From e94bbe8c5da4273a084de03294e21c630cbbb7b0 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Fri, 24 Jan 2025 14:43:00 +0100 Subject: [PATCH 02/50] Return the best graph from optimize_acqf_graph function --- .../graph_aware_gp_optimization_example.py | 6 +- .../graph_aware_gp_optimization_example.py | 127 ------------------ neps/optimizers/models/graphs/optimization.py | 32 +++-- .../test_optimization_over_graphs.py | 48 +++++-- 4 files changed, 63 insertions(+), 150 deletions(-) delete mode 100644 neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py diff --git a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py index 1f9a8216b..ffe7a31da 100644 --- a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py +++ b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py @@ -108,7 +108,7 @@ product(*cats_per_column.values())] # Optimize the acquisition function with graph sampling -best_candidate, best_score = optimize_acqf_graph( +best_candidate, best_graph, best_score = optimize_acqf_graph( acq_function=acq_function, bounds=bounds, fixed_features_list=fixed_cats, @@ -120,6 +120,10 @@ ) # Print the results +print(f"Best candidate: {best_candidate}") +print(f"Best graph: {best_graph}") +print(f"Best score: {best_score}") +print(f"Execution time: {time.time() - start_time:.2f} seconds") # Clear caches after optimization to avoid memory leaks or unexpected behavior BoTorchWLKernel._compute_kernel.cache_clear() diff --git a/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py deleted file mode 100644 index 1f9a8216b..000000000 --- a/neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py +++ /dev/null @@ -1,127 +0,0 @@ -from __future__ import annotations - -import time -from itertools import product -from typing import TYPE_CHECKING - -import networkx as nx -import torch -from botorch import fit_gpytorch_mll, settings -from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement -from botorch.models import SingleTaskGP -from 
botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from gpytorch import ExactMarginalLogLikelihood -from gpytorch.kernels import AdditiveKernel, MaternKernel - -from neps.optimizers.models.graphs.context_managers import set_graph_lookup -from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel -from neps.optimizers.models.graphs.optimization import optimize_acqf_graph -from neps.optimizers.models.graphs.utils import min_max_scale, seed_all - -if TYPE_CHECKING: - from gpytorch.distributions.multivariate_normal import MultivariateNormal - -start_time = time.time() -settings.debug._set_state(True) -seed_all() - -TRAIN_CONFIGS = 50 -TEST_CONFIGS = 10 -TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS - -N_NUMERICAL = 2 -N_CATEGORICAL = 1 -N_CATEGORICAL_VALUES_PER_CATEGORY = 2 -N_GRAPH = 1 - -assert N_GRAPH == 1, "This example only supports a single graph feature" - -# Generate random data -X = torch.cat([ - torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), - torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), - dtype=torch.float64), - torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) -], dim=1) - -# Generate random graphs -graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] - -# Generate random target values -y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 - -# Split into train and test sets -train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] -train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] -train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) - -train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) - -kernels = [ - ScaleKernel( - MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), - ScaleKernel(CategoricalKernel( - ard_num_dims=N_CATEGORICAL, - active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), - ScaleKernel(BoTorchWLKernel( - graph_lookup=train_graphs, n_iter=5, normalize=True, - active_dims=(X.shape[1] - 1,))) -] - -# Create the Gaussian Process model -gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) - -# Compute the posterior distribution -multivariate_normal: MultivariateNormal = gp.forward(train_x) - -# Making predictions on test data -with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs, append=False): - posterior = gp.forward(test_x) - predictions = posterior.mean - uncertainties = posterior.variance.sqrt() - covar = posterior.covariance_matrix - -# Fit the GP model -mll = ExactMarginalLogLikelihood(gp.likelihood, gp) -fit_gpytorch_mll(mll) - -# Define the acquisition function -acq_function = qLogNoisyExpectedImprovement( - model=gp, - X_baseline=train_x, - objective=LinearMCObjective(weights=torch.tensor([-1.0])), - prune_baseline=True, -) - -# Define the bounds for optimization -bounds = torch.tensor([ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, - [1.0] * N_NUMERICAL + [ - float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ - len(X) - 1] * N_GRAPH, -]) - -# Define fixed categorical features -cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in - range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} -fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in - product(*cats_per_column.values())] - -# Optimize the acquisition function with graph sampling -best_candidate, best_score = optimize_acqf_graph( - 
acq_function=acq_function, - bounds=bounds, - fixed_features_list=fixed_cats, - train_graphs=train_graphs, - num_graph_samples=2, - num_restarts=2, - raw_samples=16, - q=1, -) - -# Print the results - -# Clear caches after optimization to avoid memory leaks or unexpected behavior -BoTorchWLKernel._compute_kernel.cache_clear() -TorchWLKernel._get_node_neighbors.cache_clear() -TorchWLKernel._wl_iteration.cache_clear() diff --git a/neps/optimizers/models/graphs/optimization.py b/neps/optimizers/models/graphs/optimization.py index 07542c3ae..6ca43f093 100644 --- a/neps/optimizers/models/graphs/optimization.py +++ b/neps/optimizers/models/graphs/optimization.py @@ -22,13 +22,13 @@ def optimize_acqf_graph( num_restarts: int = 10, raw_samples: int = 1024, q: int = 1, -) -> tuple[torch.Tensor, float]: +) -> tuple[torch.Tensor, nx.Graph, float]: """Optimize an acquisition function with graph sampling. This function optimizes the acquisition function by sampling graphs from the training set, temporarily updating the kernel's graph lookup, and evaluating the acquisition - function for each sampled graph. The best candidate and its corresponding acquisition - score are returned. + function for each sampled graph. The best candidate, the best graph, and its + corresponding acquisition score are returned. Args: acq_function (AcquisitionFunction): The acquisition function to optimize. @@ -47,8 +47,8 @@ def optimize_acqf_graph( q (int): The number of candidates to generate. Defaults to 1. Returns: - tuple[torch.Tensor, float]: A tuple containing the best candidate (as a tensor) - and its corresponding acquisition score. + tuple[torch.Tensor, nx.Graph, float]: A tuple containing the best candidate + (as a tensor), the best graph, and its corresponding acquisition score. Raises: ValueError: If `train_graphs` is None. @@ -56,23 +56,22 @@ def optimize_acqf_graph( if train_graphs is None: raise ValueError("train_graphs cannot be None.") - # Sample graphs from the training set sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) - # Initialize lists to store the best candidates and their scores - best_candidates, best_scores = [], [] + best_candidates, best_graphs, best_scores = [], [], [] # Get the index of the graph feature in the bounds graph_idx = bounds.shape[1] - 1 - # Iterate through each sampled graph + # Todo: Instead of iterating over the graphs, optimize by putting all + # sampled graphs into the kernel and compute the scores in a single batch. + # Update the caching logic accordingly. 
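+    # A possible (untested) sketch of that batched idea, to make the TODO concrete:
+    # append every sampled graph to the lookup once and let `optimize_acqf_mixed`
+    # sweep over one fixed-feature dict per (graph, categorical combination). The
+    # indexing below (len(train_graphs) + i addressing the i-th appended graph) and
+    # the widened graph-index bounds it would require are assumptions, not something
+    # this patch implements:
+    #
+    #     with set_graph_lookup(
+    #         acq_function.model.covar_module, sampled_graphs, append=True
+    #     ):
+    #         fixed = [
+    #             {**ff, graph_idx: float(len(train_graphs) + i)}
+    #             for i in range(len(sampled_graphs))
+    #             for ff in (fixed_features_list or [{}])
+    #         ]
+    #         candidates, scores = optimize_acqf_mixed(
+    #             acq_function=acq_function,
+    #             bounds=bounds,
+    #             fixed_features_list=fixed,
+    #             num_restarts=num_restarts,
+    #             raw_samples=raw_samples,
+    #             q=q,
+    #         )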
for graph in sampled_graphs: - # Temporarily set the graph lookup for the kernel with set_graph_lookup(acq_function.model.covar_module, [graph], append=True): # Iterate through each fixed feature configuration (if provided) for fixed_features in fixed_features_list or [{}]: # Add the graph index to the fixed features, indicating that the last - # graphin the lookup should be used + # graph in the lookup should be used updated_fixed_features = {**fixed_features, graph_idx: -1.0} # Optimize the acquisition function with the updated fixed features @@ -85,12 +84,17 @@ def optimize_acqf_graph( q=q, ) - # Store the candidates and their scores + # Store the candidates, graphs, and their scores best_candidates.append(candidates) + best_graphs.append(graph) best_scores.append(scores) # Find the index of the best score best_idx = torch.argmax(torch.tensor(best_scores)) - # Return the best candidate and its score - return best_candidates[best_idx], best_scores[best_idx].item() + # Return the best candidate (without the graph index), the best graph, and its score + return ( + best_candidates[best_idx][:, :-1], + best_graphs[best_idx], + best_scores[best_idx].item() + ) diff --git a/tests/test_graphs/test_optimization_over_graphs.py b/tests/test_graphs/test_optimization_over_graphs.py index f4d89778f..49056a21a 100644 --- a/tests/test_graphs/test_optimization_over_graphs.py +++ b/tests/test_graphs/test_optimization_over_graphs.py @@ -168,7 +168,7 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None: product(*cats_per_column.values())] # Optimize the acquisition function - best_candidate, best_score = optimize_acqf_graph( + best_candidate, best_graph, best_score = optimize_acqf_graph( acq_function=acq_function, bounds=bounds, fixed_features_list=fixed_cats, @@ -179,9 +179,17 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None: q=1, ) - # Basic checks - assert best_candidate.shape == (1, train_x.shape[1]) - assert isinstance(best_score, float) + # Assertions for the acquisition function optimization + assert isinstance(best_candidate, + torch.Tensor), "Best candidate should be a tensor" + assert best_candidate.shape == (1, train_x.shape[1] - 1), \ + "Best candidate should have the correct shape (excluding the graph index)" + assert isinstance(best_graph, nx.Graph), "Best graph should be a NetworkX graph" + assert isinstance(best_score, float), "Best score should be a float" + + # Ensure the best candidate does not contain the graph index column + assert best_candidate.shape[1] == train_x.shape[1] - 1, \ + "Best candidate should not include the graph index column" def test_graph_sampling(self, setup_data: dict) -> None: """Test the graph sampling functionality.""" @@ -192,10 +200,34 @@ def test_graph_sampling(self, setup_data: dict) -> None: sampled_graphs = sample_graphs(train_graphs, num_samples=num_samples) # Basic checks - assert len(sampled_graphs) == num_samples - for graph in sampled_graphs: - assert isinstance(graph, nx.Graph) - assert nx.is_connected(graph) + assert len(sampled_graphs) == num_samples, \ + f"Expected {num_samples} sampled graphs, got {len(sampled_graphs)}" + assert all(isinstance(graph, nx.Graph) for graph in sampled_graphs), \ + "All sampled graphs should be NetworkX graphs" + assert all(nx.is_connected(graph) for graph in sampled_graphs), \ + "All sampled graphs should be connected" + + def test_min_max_scaling(self, setup_data: dict) -> None: + """Test the min-max scaling utility.""" + train_x = setup_data["train_x"] + + # 
Apply min-max scaling + scaled_train_x = min_max_scale(train_x) + + # Assertions for min-max scaling + assert torch.all(scaled_train_x >= 0), "Scaled values should be >= 0" + assert torch.all(scaled_train_x <= 1), "Scaled values should be <= 1" + assert scaled_train_x.shape == train_x.shape, \ + "Scaled data should have the same shape as the input data" + + # Check that the scaling is correct + for i in range(train_x.shape[1]): + col_min = torch.min(train_x[:, i]) + col_max = torch.max(train_x[:, i]) + if col_min != col_max: # Avoid division by zero + expected_scaled_col = (train_x[:, i] - col_min) / (col_max - col_min) + assert torch.allclose(scaled_train_x[:, i], expected_scaled_col), \ + f"Scaling is incorrect for column {i}" def test_set_graph_lookup(self, setup_data: dict) -> None: """Test the set_graph_lookup context manager.""" From 15877784748b7f0555fbfe69e7fc601a0754bf7b Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Fri, 24 Jan 2025 18:45:34 +0100 Subject: [PATCH 03/50] Refactor method that uses lru_cache into a standalone function --- .../models/graphs/context_managers.py | 8 +- .../graph_aware_gp_optimization_example.py | 12 +- neps/optimizers/models/graphs/kernels.py | 111 ++++++++++-------- 3 files changed, 69 insertions(+), 62 deletions(-) diff --git a/neps/optimizers/models/graphs/context_managers.py b/neps/optimizers/models/graphs/context_managers.py index 32d83abeb..e8f69db37 100644 --- a/neps/optimizers/models/graphs/context_managers.py +++ b/neps/optimizers/models/graphs/context_managers.py @@ -6,7 +6,7 @@ from botorch.models import SingleTaskGP -from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, compute_kernel if TYPE_CHECKING: import networkx as nx @@ -45,11 +45,7 @@ def set_graph_lookup( # Save the current graph lookup and set the new graph lookup for kern in modules: - if isinstance(kern, TorchWLKernel): - kern._get_node_neighbors.cache_clear() - kern._wl_iteration.cache_clear() - elif isinstance(kern, BoTorchWLKernel): - kern._compute_kernel.cache_clear() + compute_kernel.cache_clear() kernel_prev_graphs.append((kern, kern.graph_lookup)) if append: diff --git a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py index ffe7a31da..3c4c1f5fd 100644 --- a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py +++ b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py @@ -14,7 +14,7 @@ from gpytorch.kernels import AdditiveKernel, MaternKernel from neps.optimizers.models.graphs.context_managers import set_graph_lookup -from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, TorchWLKernel +from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, compute_kernel from neps.optimizers.models.graphs.optimization import optimize_acqf_graph from neps.optimizers.models.graphs.utils import min_max_scale, seed_all @@ -45,7 +45,7 @@ ], dim=1) # Generate random graphs -graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] +graphs = [nx.erdos_renyi_graph(50, 0.5) for _ in range(TOTAL_CONFIGS)] # Generate random target values y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 @@ -113,9 +113,9 @@ bounds=bounds, fixed_features_list=fixed_cats, train_graphs=train_graphs, - num_graph_samples=2, + num_graph_samples=16, num_restarts=2, - raw_samples=16, + raw_samples=32, q=1, ) @@ -126,6 
+126,4 @@ print(f"Execution time: {time.time() - start_time:.2f} seconds") # Clear caches after optimization to avoid memory leaks or unexpected behavior -BoTorchWLKernel._compute_kernel.cache_clear() -TorchWLKernel._get_node_neighbors.cache_clear() -TorchWLKernel._wl_iteration.cache_clear() +compute_kernel.cache_clear() diff --git a/neps/optimizers/models/graphs/kernels.py b/neps/optimizers/models/graphs/kernels.py index 5e9a38375..6d970d835 100644 --- a/neps/optimizers/models/graphs/kernels.py +++ b/neps/optimizers/models/graphs/kernels.py @@ -14,6 +14,57 @@ import networkx as nx +@lru_cache(maxsize=128) +def compute_kernel( + adjacency_cache: tuple[Tensor, ...], + label_cache: tuple[Tensor, ...], + indices1: tuple[int, ...], + indices2: tuple[int, ...], + n_iter: int, + *, + diag: bool, + normalize: bool, +) -> Tensor: + """Compute the kernel matrix. + + This function is defined outside the class to leverage the `lru_cache` decorator, + which caches the results of expensive function calls and reuses them when the same + inputs occur again. + + Args: + adjacency_cache: Tuple of adjacency matrices for the graphs. + label_cache: Tuple of initial node labels for the graphs. + indices1: Tuple of indices for the first set of graphs. + indices2: Tuple of indices for the second set of graphs. + n_iter: Number of WL iterations. + diag: Whether to return only the diagonal of the kernel matrix. + normalize: Whether to normalize the kernel matrix. + + Returns: + A Tensor representing the kernel matrix. + """ + all_graphs = list(set(indices1).union(indices2)) + adj_matrices = [adjacency_cache[i] for i in all_graphs] + label_tensors = [label_cache[i] for i in all_graphs] + + # Compute full kernel matrix + _kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + K_full = _kernel(adj_matrices, label_tensors) + + # Map indices to their positions in all_graphs + idx1 = [all_graphs.index(i) for i in indices1] + idx2 = [all_graphs.index(i) for i in indices2] + + # Extract the relevant submatrix + K = K_full[idx1][:, idx2] + + # Return the diagonal if requested + if diag: + return torch.diag(K) + + return K + + class BoTorchWLKernel(Kernel): """A custom kernel for Gaussian Processes using the Weisfeiler-Lehman (WL) algorithm. @@ -79,13 +130,21 @@ def forward( raise NotImplementedError("Batch dimension handling is not implemented.") if x1.ndim == 3: - return self._handle_batched_input(x1, x2, diag) + return self._handle_batched_input(x1=x1, x2=x2, diag=diag) indices1, indices2 = self._prepare_indices(x1, x2) - return self._compute_kernel(tuple(indices1), tuple(indices2), diag) + return compute_kernel( + adjacency_cache=tuple(self.adjacency_cache), + label_cache=tuple(self.label_cache), + indices1=tuple(indices1), + indices2=tuple(indices2), + n_iter=self.n_iter, + diag=diag, + normalize=self.normalize, + ) - def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor: + def _handle_batched_input(self, x1: Tensor, x2: Tensor, *, diag: bool) -> Tensor: """Handle computation for batched input tensors.""" q_dim_size = x1.shape[0] assert x2.shape[0] == q_dim_size @@ -120,50 +179,6 @@ def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int] return indices1, indices2 - @lru_cache(maxsize=128) - def _compute_kernel( - self, - indices1: tuple[int, ...], - indices2: tuple[int, ...], - diag: bool, - ) -> Tensor: - """Compute the kernel matrix. - - Args: - indices1: Tuple of indices for the first set of graphs. - indices2: Tuple of indices for the second set of graphs. 
- diag: Whether to return only the diagonal of the kernel matrix. - - Returns: - A Tensor representing the kernel matrix. - """ - all_graphs = list(set(indices1).union(indices2)) - adj_matrices = [self.adjacency_cache[i] for i in all_graphs] - label_tensors = [self.label_cache[i] for i in all_graphs] - - # Compute full kernel matrix - K_full = self._compute_base_kernel(adj_matrices, label_tensors) - - # Map indices to their positions in all_graphs - idx1 = [all_graphs.index(i) for i in indices1] - idx2 = [all_graphs.index(i) for i in indices2] - - # Extract the relevant submatrix - K = K_full[idx1][:, idx2] - - # Return the diagonal if requested - if diag: - return torch.diag(K) - - return K - - def _compute_base_kernel( - self, adj_matrices: list[Tensor], label_tensors: list[Tensor] - ) -> Tensor: - """Compute the base kernel matrix using WL algorithm.""" - _kernel = TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize) - return _kernel(adj_matrices, label_tensors) - class TorchWLKernel(Module): """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. @@ -192,7 +207,6 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: self.label_dict: dict[str, int] = {} self.label_counter: int = 0 - @lru_cache(maxsize=128) def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]: """Extract neighborhood information from adjacency matrix.""" if adj.layout == torch.sparse_csr: @@ -208,7 +222,6 @@ def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]: return neighbors - @lru_cache(maxsize=128) def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: """Perform one WL iteration.""" if not self.label_dict: From 10221a99b6553482c97748c15f3a108eb5e65d5a Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Fri, 24 Jan 2025 18:46:16 +0100 Subject: [PATCH 04/50] Remove examples --- .../examples/grakel_wl_usage_example.py | 50 ------- .../graph_aware_gp_optimization_example.py | 129 ----------------- .../examples/single_task_gp_usage_example.py | 133 ------------------ 3 files changed, 312 deletions(-) delete mode 100644 neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py delete mode 100644 neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py delete mode 100644 neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py diff --git a/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py b/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py deleted file mode 100644 index 8b8c6a7ee..000000000 --- a/neps/optimizers/models/graphs/examples/grakel_wl_usage_example.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -import matplotlib.pyplot as plt -import networkx as nx -from grakel import WeisfeilerLehman, graph_from_networkx - - -def visualize_graph(G: nx.Graph): - """Visualize the NetworkX graph.""" - pos = nx.spring_layout(G) - nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue") - plt.show() - - -def add_labels(G: nx.Graph): - """Add labels to the nodes of the graph.""" - for node in G.nodes(): - G.nodes[node]["label"] = str(node) - - -# Create graphs -G1 = nx.Graph() -G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) -add_labels(G1) - -G2 = nx.Graph() -G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) -add_labels(G2) - -G3 = nx.Graph() -G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) -add_labels(G3) - -# Visualize the graphs -visualize_graph(G1) -visualize_graph(G2) -visualize_graph(G3) - -# Convert NetworkX graphs to Grakel 
format using graph_from_networkx -graph_list = list( - graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True) -) - -# Initialize the Weisfeiler-Lehman kernel -wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False) - -# Compute the kernel matrix -K = wl_kernel.fit_transform(graph_list) - -# Display the kernel matrix diff --git a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py b/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py deleted file mode 100644 index 3c4c1f5fd..000000000 --- a/neps/optimizers/models/graphs/examples/graph_aware_gp_optimization_example.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import annotations - -import time -from itertools import product -from typing import TYPE_CHECKING - -import networkx as nx -import torch -from botorch import fit_gpytorch_mll, settings -from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement -from botorch.models import SingleTaskGP -from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from gpytorch import ExactMarginalLogLikelihood -from gpytorch.kernels import AdditiveKernel, MaternKernel - -from neps.optimizers.models.graphs.context_managers import set_graph_lookup -from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, compute_kernel -from neps.optimizers.models.graphs.optimization import optimize_acqf_graph -from neps.optimizers.models.graphs.utils import min_max_scale, seed_all - -if TYPE_CHECKING: - from gpytorch.distributions.multivariate_normal import MultivariateNormal - -start_time = time.time() -settings.debug._set_state(True) -seed_all() - -TRAIN_CONFIGS = 50 -TEST_CONFIGS = 10 -TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS - -N_NUMERICAL = 2 -N_CATEGORICAL = 1 -N_CATEGORICAL_VALUES_PER_CATEGORY = 2 -N_GRAPH = 1 - -assert N_GRAPH == 1, "This example only supports a single graph feature" - -# Generate random data -X = torch.cat([ - torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), - torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), - dtype=torch.float64), - torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) -], dim=1) - -# Generate random graphs -graphs = [nx.erdos_renyi_graph(50, 0.5) for _ in range(TOTAL_CONFIGS)] - -# Generate random target values -y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 - -# Split into train and test sets -train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] -train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] -train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) - -train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) - -kernels = [ - ScaleKernel( - MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), - ScaleKernel(CategoricalKernel( - ard_num_dims=N_CATEGORICAL, - active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), - ScaleKernel(BoTorchWLKernel( - graph_lookup=train_graphs, n_iter=5, normalize=True, - active_dims=(X.shape[1] - 1,))) -] - -# Create the Gaussian Process model -gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) - -# Compute the posterior distribution -multivariate_normal: MultivariateNormal = gp.forward(train_x) - -# Making predictions on test data -with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs, append=False): - posterior = gp.forward(test_x) - predictions = posterior.mean - uncertainties = 
posterior.variance.sqrt() - covar = posterior.covariance_matrix - -# Fit the GP model -mll = ExactMarginalLogLikelihood(gp.likelihood, gp) -fit_gpytorch_mll(mll) - -# Define the acquisition function -acq_function = qLogNoisyExpectedImprovement( - model=gp, - X_baseline=train_x, - objective=LinearMCObjective(weights=torch.tensor([-1.0])), - prune_baseline=True, -) - -# Define the bounds for optimization -bounds = torch.tensor([ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, - [1.0] * N_NUMERICAL + [ - float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ - len(X) - 1] * N_GRAPH, -]) - -# Define fixed categorical features -cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in - range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} -fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in - product(*cats_per_column.values())] - -# Optimize the acquisition function with graph sampling -best_candidate, best_graph, best_score = optimize_acqf_graph( - acq_function=acq_function, - bounds=bounds, - fixed_features_list=fixed_cats, - train_graphs=train_graphs, - num_graph_samples=16, - num_restarts=2, - raw_samples=32, - q=1, -) - -# Print the results -print(f"Best candidate: {best_candidate}") -print(f"Best graph: {best_graph}") -print(f"Best score: {best_score}") -print(f"Execution time: {time.time() - start_time:.2f} seconds") - -# Clear caches after optimization to avoid memory leaks or unexpected behavior -compute_kernel.cache_clear() diff --git a/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py b/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py deleted file mode 100644 index b35cb3b51..000000000 --- a/neps/optimizers/models/graphs/examples/single_task_gp_usage_example.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import annotations - -from itertools import product -from typing import TYPE_CHECKING - -import torch -from botorch import fit_gpytorch_mll -from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement -from botorch.models import SingleTaskGP -from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from botorch.optim import optimize_acqf_mixed -from gpytorch import ExactMarginalLogLikelihood -from gpytorch.kernels import AdditiveKernel, MaternKernel - -if TYPE_CHECKING: - from gpytorch.distributions.multivariate_normal import MultivariateNormal - -TRAIN_CONFIGS = 10 -TEST_CONFIGS = 10 -TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS - -N_NUMERICAL = 2 -N_CATEGORICAL = 2 -N_CATEGORICAL_VALUES_PER_CATEGORY = 3 - -kernels = [] - -# Create some random encoded hyperparameter configurations -X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) -if N_NUMERICAL > 0: - X[:, :N_NUMERICAL] = torch.rand( - size=(TOTAL_CONFIGS, N_NUMERICAL), - dtype=torch.float64, - ) - -if N_CATEGORICAL > 0: - X[:, N_NUMERICAL:] = torch.randint( - 0, - N_CATEGORICAL_VALUES_PER_CATEGORY, - size=(TOTAL_CONFIGS, N_CATEGORICAL), - dtype=torch.float64, - ) - -y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) - -if N_NUMERICAL > 0: - matern = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=N_NUMERICAL, - active_dims=tuple(range(N_NUMERICAL)), - ), - ) - kernels.append(matern) - -if N_CATEGORICAL > 0: - hamming = ScaleKernel( - CategoricalKernel( - ard_num_dims=N_CATEGORICAL, - active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), - ), - ) - kernels.append(hamming) - -combined_num_cat_kernel = 
AdditiveKernel(*kernels) - -train_x = X[:TRAIN_CONFIGS] -train_y = y[:TRAIN_CONFIGS] - -test_x = X[TRAIN_CONFIGS:] -test_y = y[TRAIN_CONFIGS:] - -K_matrix = combined_num_cat_kernel.forward(train_x, train_x) - -train_y = train_y.unsqueeze(-1) -test_y = test_y.unsqueeze(-1) - -gp = SingleTaskGP( - train_X=train_x, - train_Y=train_y, - covar_module=combined_num_cat_kernel, -) - -multivariate_normal: MultivariateNormal = gp.forward(train_x) - -# =============== Fitting the GP using botorch =============== - - -mll = ExactMarginalLogLikelihood(gp.likelihood, gp) -fit_gpytorch_mll(mll) - -acq_function = qLogNoisyExpectedImprovement( - model=gp, - X_baseline=train_x, - objective=LinearMCObjective(weights=torch.tensor([-1.0])), - prune_baseline=True, -) - -# Define bounds -bounds = torch.tensor( - [ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, - [1.0] * N_NUMERICAL + [ - float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL - ] -) - -# Setup categorical feature optimization -cats_per_column: dict[int, list[float]] = { - column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] - for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) -} - -# Generate fixed categorical features -fixed_cats: list[dict[int, float]] -if len(cats_per_column) == 1: - col, choice_indices = next(iter(cats_per_column.items())) - fixed_cats = [{col: i} for i in choice_indices] -else: - fixed_cats = [ - dict(zip(cats_per_column.keys(), combo, strict=False)) - for combo in product(*cats_per_column.values()) - ] - -best_candidate, best_score = optimize_acqf_mixed( - acq_function=acq_function, - bounds=bounds, - fixed_features_list=fixed_cats, - num_restarts=10, - raw_samples=10, - q=1, -) - From 10247b2462af43311a0bb4ee390f5f3c6d26edab Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Mon, 27 Jan 2025 08:29:48 +0100 Subject: [PATCH 05/50] fix: ruff format --- .../models/graphs/context_managers.py | 17 +- neps/optimizers/models/graphs/kernels.py | 1 + neps/optimizers/models/graphs/optimization.py | 2 +- neps/optimizers/models/graphs/utils.py | 3 +- .../test_optimization_over_graphs.py | 150 ++++++++++++------ tests/test_graphs/test_torch_wl_kernel.py | 45 +++--- 6 files changed, 140 insertions(+), 78 deletions(-) diff --git a/neps/optimizers/models/graphs/context_managers.py b/neps/optimizers/models/graphs/context_managers.py index e8f69db37..56b643595 100644 --- a/neps/optimizers/models/graphs/context_managers.py +++ b/neps/optimizers/models/graphs/context_managers.py @@ -33,15 +33,20 @@ def set_graph_lookup( # Determine the modules to update based on the input type if isinstance(kernel_or_gp, SingleTaskGP): - modules = [k for k in kernel_or_gp.covar_module.sub_kernels() if - isinstance(k, BoTorchWLKernel)] + modules = [ + k + for k in kernel_or_gp.covar_module.sub_kernels() + if isinstance(k, BoTorchWLKernel) + ] elif isinstance(kernel_or_gp, BoTorchWLKernel): modules = [kernel_or_gp] else: - assert hasattr(kernel_or_gp, - "sub_kernels"), "Kernel module must have sub_kernels method." - modules = [k for k in kernel_or_gp.sub_kernels() if - isinstance(k, BoTorchWLKernel)] + assert hasattr(kernel_or_gp, "sub_kernels"), ( + "Kernel module must have sub_kernels method." 
+ ) + modules = [ + k for k in kernel_or_gp.sub_kernels() if isinstance(k, BoTorchWLKernel) + ] # Save the current graph lookup and set the new graph lookup for kern in modules: diff --git a/neps/optimizers/models/graphs/kernels.py b/neps/optimizers/models/graphs/kernels.py index 6d970d835..64f4a7f71 100644 --- a/neps/optimizers/models/graphs/kernels.py +++ b/neps/optimizers/models/graphs/kernels.py @@ -88,6 +88,7 @@ class BoTorchWLKernel(Kernel): adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs. label_cache (list[Tensor]): Cached initial node labels of the graphs. """ + has_lengthscale = False def __init__( diff --git a/neps/optimizers/models/graphs/optimization.py b/neps/optimizers/models/graphs/optimization.py index 6ca43f093..c036f5816 100644 --- a/neps/optimizers/models/graphs/optimization.py +++ b/neps/optimizers/models/graphs/optimization.py @@ -96,5 +96,5 @@ def optimize_acqf_graph( return ( best_candidates[best_idx][:, :-1], best_graphs[best_idx], - best_scores[best_idx].item() + best_scores[best_idx].item(), ) diff --git a/neps/optimizers/models/graphs/utils.py b/neps/optimizers/models/graphs/utils.py index 22e9d8a64..056921554 100644 --- a/neps/optimizers/models/graphs/utils.py +++ b/neps/optimizers/models/graphs/utils.py @@ -26,8 +26,7 @@ def min_max_scale(tensor: torch.Tensor) -> torch.Tensor: def graphs_to_tensors( - graphs: list[nx.Graph], - device: torch.device | None = None + graphs: list[nx.Graph], device: torch.device | None = None ) -> tuple[list[torch.sparse.Tensor], list[torch.Tensor]]: """Convert a list of NetworkX graphs into sparse adjacency matrices and label tensors. diff --git a/tests/test_graphs/test_optimization_over_graphs.py b/tests/test_graphs/test_optimization_over_graphs.py index 49056a21a..958031af0 100644 --- a/tests/test_graphs/test_optimization_over_graphs.py +++ b/tests/test_graphs/test_optimization_over_graphs.py @@ -32,12 +32,19 @@ def setup_data(self) -> dict: N_GRAPH = 1 # Generate random data - X = torch.cat([ - torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), - torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, - (TOTAL_CONFIGS, N_CATEGORICAL), dtype=torch.float64), - torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) - ], dim=1) + X = torch.cat( + [ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1), + ], + dim=1, + ) # Generate random graphs graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] @@ -76,23 +83,36 @@ def test_gp_fit_and_predict(self, setup_data: dict) -> None: # Define the kernels kernels = [ - ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=setup_data["N_NUMERICAL"], - active_dims=range(setup_data["N_NUMERICAL"]))), ScaleKernel( - CategoricalKernel(ard_num_dims=setup_data["N_CATEGORICAL"], - active_dims=range(setup_data["N_NUMERICAL"], - setup_data["N_NUMERICAL"] + - setup_data["N_CATEGORICAL"]) - ) + MaternKernel( + nu=2.5, + ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]), + ) + ), + ScaleKernel( + CategoricalKernel( + ard_num_dims=setup_data["N_CATEGORICAL"], + active_dims=range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ), + ) ), ScaleKernel( - BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, - active_dims=(train_x.shape[1] - 1,))) + BoTorchWLKernel( + 
graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(train_x.shape[1] - 1,), + ) + ), ] # Create the GP model - gp = SingleTaskGP(train_X=train_x, train_Y=train_y, - covar_module=AdditiveKernel(*kernels)) + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels) + ) # Fit the GP mll = ExactMarginalLogLikelihood(gp.likelihood, gp) @@ -119,24 +139,36 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None: # Define the kernels kernels = [ - ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=setup_data["N_NUMERICAL"], - active_dims=range(setup_data["N_NUMERICAL"]))), + ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=setup_data["N_NUMERICAL"], + active_dims=range(setup_data["N_NUMERICAL"]), + ) + ), ScaleKernel( CategoricalKernel( ard_num_dims=setup_data["N_CATEGORICAL"], - active_dims=range(setup_data["N_NUMERICAL"], - setup_data["N_NUMERICAL"] + - setup_data["N_CATEGORICAL"]) + active_dims=range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ), ) ), ScaleKernel( - BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, - active_dims=(train_x.shape[1] - 1,))) + BoTorchWLKernel( + graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(train_x.shape[1] - 1,), + ) + ), ] # Create the GP model - gp = SingleTaskGP(train_X=train_x, train_Y=train_y, - covar_module=AdditiveKernel(*kernels)) + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels) + ) # Fit the GP mll = ExactMarginalLogLikelihood(gp.likelihood, gp) @@ -151,21 +183,30 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None: ) # Define bounds for optimization - bounds = torch.tensor([ - [0.0] * setup_data["N_NUMERICAL"] + [0.0] * setup_data["N_CATEGORICAL"] + [ - -1.0] * setup_data["N_GRAPH"], - [1.0] * setup_data["N_NUMERICAL"] + [ - float(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"] - 1)] * setup_data[ - "N_CATEGORICAL"] + [len(train_x) - 1] * setup_data["N_GRAPH"], - ]) + bounds = torch.tensor( + [ + [0.0] * setup_data["N_NUMERICAL"] + + [0.0] * setup_data["N_CATEGORICAL"] + + [-1.0] * setup_data["N_GRAPH"], + [1.0] * setup_data["N_NUMERICAL"] + + [float(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"] - 1)] + * setup_data["N_CATEGORICAL"] + + [len(train_x) - 1] * setup_data["N_GRAPH"], + ] + ) # Define fixed categorical features - cats_per_column = {i: list(range(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"])) - for i in range(setup_data["N_NUMERICAL"], - setup_data["N_NUMERICAL"] + setup_data[ - "N_CATEGORICAL"])} - fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in - product(*cats_per_column.values())] + cats_per_column = { + i: list(range(setup_data["N_CATEGORICAL_VALUES_PER_CATEGORY"])) + for i in range( + setup_data["N_NUMERICAL"], + setup_data["N_NUMERICAL"] + setup_data["N_CATEGORICAL"], + ) + } + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo, strict=False)) + for combo in product(*cats_per_column.values()) + ] # Optimize the acquisition function best_candidate, best_graph, best_score = optimize_acqf_graph( @@ -180,16 +221,19 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None: ) # Assertions for the acquisition function optimization - assert isinstance(best_candidate, - torch.Tensor), "Best candidate should be a tensor" - assert best_candidate.shape == (1, train_x.shape[1] - 1), \ + assert isinstance(best_candidate, torch.Tensor), ( + "Best 
candidate should be a tensor" + ) + assert best_candidate.shape == (1, train_x.shape[1] - 1), ( "Best candidate should have the correct shape (excluding the graph index)" + ) assert isinstance(best_graph, nx.Graph), "Best graph should be a NetworkX graph" assert isinstance(best_score, float), "Best score should be a float" # Ensure the best candidate does not contain the graph index column - assert best_candidate.shape[1] == train_x.shape[1] - 1, \ + assert best_candidate.shape[1] == train_x.shape[1] - 1, ( "Best candidate should not include the graph index column" + ) def test_graph_sampling(self, setup_data: dict) -> None: """Test the graph sampling functionality.""" @@ -200,12 +244,15 @@ def test_graph_sampling(self, setup_data: dict) -> None: sampled_graphs = sample_graphs(train_graphs, num_samples=num_samples) # Basic checks - assert len(sampled_graphs) == num_samples, \ + assert len(sampled_graphs) == num_samples, ( f"Expected {num_samples} sampled graphs, got {len(sampled_graphs)}" - assert all(isinstance(graph, nx.Graph) for graph in sampled_graphs), \ + ) + assert all(isinstance(graph, nx.Graph) for graph in sampled_graphs), ( "All sampled graphs should be NetworkX graphs" - assert all(nx.is_connected(graph) for graph in sampled_graphs), \ + ) + assert all(nx.is_connected(graph) for graph in sampled_graphs), ( "All sampled graphs should be connected" + ) def test_min_max_scaling(self, setup_data: dict) -> None: """Test the min-max scaling utility.""" @@ -217,8 +264,9 @@ def test_min_max_scaling(self, setup_data: dict) -> None: # Assertions for min-max scaling assert torch.all(scaled_train_x >= 0), "Scaled values should be >= 0" assert torch.all(scaled_train_x <= 1), "Scaled values should be <= 1" - assert scaled_train_x.shape == train_x.shape, \ + assert scaled_train_x.shape == train_x.shape, ( "Scaled data should have the same shape as the input data" + ) # Check that the scaling is correct for i in range(train_x.shape[1]): @@ -226,8 +274,9 @@ def test_min_max_scaling(self, setup_data: dict) -> None: col_max = torch.max(train_x[:, i]) if col_min != col_max: # Avoid division by zero expected_scaled_col = (train_x[:, i] - col_min) / (col_max - col_min) - assert torch.allclose(scaled_train_x[:, i], expected_scaled_col), \ + assert torch.allclose(scaled_train_x[:, i], expected_scaled_col), ( f"Scaling is incorrect for column {i}" + ) def test_set_graph_lookup(self, setup_data: dict) -> None: """Test the set_graph_lookup context manager.""" @@ -235,8 +284,9 @@ def test_set_graph_lookup(self, setup_data: dict) -> None: test_graphs = setup_data["test_graphs"] # Define the kernel - kernel = BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True, - active_dims=(0,)) + kernel = BoTorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, active_dims=(0,) + ) # Use the context manager to temporarily set the graph lookup with set_graph_lookup(kernel, test_graphs, append=True): diff --git a/tests/test_graphs/test_torch_wl_kernel.py b/tests/test_graphs/test_torch_wl_kernel.py index a30de8a30..3e2f3c1f4 100644 --- a/tests/test_graphs/test_torch_wl_kernel.py +++ b/tests/test_graphs/test_torch_wl_kernel.py @@ -75,16 +75,19 @@ def test_wl_kernel_against_grakel( ) -> None: for graph_set in random_graphs_sets: adjacency_matrices, label_tensors = graphs_to_tensors( - graph_set, device=self.device) + graph_set, device=self.device + ) # Initialize Torch WL Kernel torch_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) - torch_kernel_matrix = 
torch_kernel(adjacency_matrices, - label_tensors).cpu().numpy() + torch_kernel_matrix = ( + torch_kernel(adjacency_matrices, label_tensors).cpu().numpy() + ) # Initialize GraKel WL Kernel grakel_graphs = list( - graph_from_networkx(graph_set, node_labels_tag="label", as_Graph=True)) + graph_from_networkx(graph_set, node_labels_tag="label", as_Graph=True) + ) grakel_kernel = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) grakel_kernel_matrix = grakel_kernel.fit_transform(grakel_graphs) @@ -94,7 +97,7 @@ def test_wl_kernel_against_grakel( grakel_kernel_matrix, rtol=1e-5, atol=1e-8, - err_msg=f"Kernel matrices differ for graph={graph_set}, n_iter={n_iter}" + err_msg=f"Kernel matrices differ for graph={graph_set}, n_iter={n_iter}", ) def test_empty_graph(self) -> None: @@ -102,8 +105,9 @@ def test_empty_graph(self) -> None: G_empty.add_node(0) G_empty.nodes[0]["label"] = "0" - adjacency_matrices, label_tensors = graphs_to_tensors([G_empty], - device=self.device) + adjacency_matrices, label_tensors = graphs_to_tensors( + [G_empty], device=self.device + ) # Initialize kernel and compute kernel = TorchWLKernel(n_iter=3, normalize=True) @@ -116,8 +120,9 @@ def test_empty_graph(self) -> None: def test_invalid_input(self) -> None: wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - with pytest.raises(ValueError, - match="Mismatch between adjacency matrices and label tensors"): + with pytest.raises( + ValueError, match="Mismatch between adjacency matrices and label tensors" + ): wl_kernel([], [torch.tensor([0])]) def test_kernel_on_single_node_graph(self) -> None: @@ -125,8 +130,9 @@ def test_kernel_on_single_node_graph(self) -> None: G_single.add_node(0) G_single.nodes[0]["label"] = "0" - adjacency_matrices, label_tensors = graphs_to_tensors([G_single], - device=self.device) + adjacency_matrices, label_tensors = graphs_to_tensors( + [G_single], device=self.device + ) wl_kernel = TorchWLKernel(n_iter=3, normalize=True) K = wl_kernel(adjacency_matrices, label_tensors) @@ -168,8 +174,9 @@ def test_wl_kernel_with_empty_graph_and_reordered_edges( K = wl_kernel(adjacency_matrices, label_tensors) assert K.shape == (3, 3), "Kernel matrix shape is incorrect" - assert torch.allclose(K[1, 1], K[2, 2]), \ + assert torch.allclose(K[1, 1], K[2, 2]), ( "Kernel value for original and reordered graphs should be the same" + ) @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) @pytest.mark.parametrize("normalize", [True, False]) @@ -184,8 +191,7 @@ def test_wl_kernel_with_different_node_labels( G_copy.nodes[node]["label"] = f"{prefix}{node}" graphs.append(G_copy) - adjacency_matrices, label_tensors = graphs_to_tensors(graphs, - device=self.device) + adjacency_matrices, label_tensors = graphs_to_tensors(graphs, device=self.device) wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) torch_kernel_matrix = wl_kernel(adjacency_matrices, label_tensors).cpu().numpy() @@ -199,7 +205,7 @@ def test_wl_kernel_with_different_node_labels( grakel_kernel_matrix, rtol=1e-5, atol=1e-8, - err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}" + err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}", ) def test_wl_kernel_with_same_node_labels( @@ -220,8 +226,7 @@ def test_wl_kernel_with_same_node_labels( G_copy.nodes[node]["label"] = "A" graphs.append(G_copy) - adjacency_matrices, label_tensors = graphs_to_tensors( - graphs, device=self.device) + adjacency_matrices, label_tensors = graphs_to_tensors(graphs, device=self.device) wl_kernel = TorchWLKernel(n_iter=3, 
normalize=True) K = wl_kernel(adjacency_matrices, label_tensors) @@ -231,13 +236,15 @@ def test_wl_kernel_with_same_node_labels( assert torch.allclose(K, K.T, atol=1e-4), "Kernel matrix is not symmetric" # Check diagonal elements are 1 (normalized self-similarity) - assert torch.allclose(torch.diag(K), torch.ones_like(torch.diag(K)), atol=1e-4), \ + assert torch.allclose(torch.diag(K), torch.ones_like(torch.diag(K)), atol=1e-4), ( "Diagonal elements should be 1.0" + ) # Check off-diagonal elements are less than 1 (different structures) off_diag_mask = ~torch.eye(K.shape[0], dtype=torch.bool, device=self.device) - assert torch.all(K[off_diag_mask] < 1.0), \ + assert torch.all(K[off_diag_mask] < 1.0), ( "Off-diagonal elements should be less than 1.0 for different structures" + ) # Check all elements are non-negative (valid kernel) assert torch.all(K >= 0), "Kernel values should be non-negative" From bd64fdc806767330ea50fa2711e91ada8364aaca Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Mon, 27 Jan 2025 08:34:42 +0100 Subject: [PATCH 06/50] fix: add grakel to dev dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a1ce332e0..3824542a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ dev = [ "mkdocs-literate-nav", "mike", "black", # This allows mkdocstrings to format signatures in the docs + "grakel==0.1.10", ] [tool.setuptools.packages.find] From 01d52ae28ca0de487db716eae70ee44949d1c03c Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 7 Feb 2025 18:19:28 +0100 Subject: [PATCH 07/50] yo --- graph.py | 931 ++++++++++++++++++++++++++++++++++++++++++++++++++ perf.py | 32 ++ test_graph.py | 160 +++++++++ 3 files changed, 1123 insertions(+) create mode 100644 graph.py create mode 100644 perf.py create mode 100644 test_graph.py diff --git a/graph.py b/graph.py new file mode 100644 index 000000000..e8bc8d2a2 --- /dev/null +++ b/graph.py @@ -0,0 +1,931 @@ +from __future__ import annotations + +import itertools +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing_extensions import assert_never + +import more_itertools +import networkx as nx +from torch import nn + +from neps.exceptions import NePSError + +if TYPE_CHECKING: + import numpy as np + + +class ParseError(NePSError): + pass + + +class ReLUConvBN(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + super().__init__() + + self.kernel_size = kernel_size + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=1, + bias=False, + ), + nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + ) + + def forward(self, x): + return self.op(x) + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self): + return self + + +class Leaf(NamedTuple): + symbol: str + op: Callable + + +class Container(NamedTuple): + symbol: str + children: list[Node] + op: Callable + + +class Passthrough(NamedTuple): + symbol: str + children: list[Node] + + +Node: TypeAlias = Container | Passthrough | Leaf + + +@dataclass +class Tree: + root: Container | Leaf + + nodes: dict[int, Node] + + children_ids_of: dict[int, list[int]] + parent_id_of: dict[int, int] + leafs: list[int] + + @classmethod + def from_node(cls, node: Node) -> Tree: + """Create a `Tree` 
from a node, where node is considered the root.""" + nodes: dict[int, Node] = {} + children_ids_of: dict[int, list[int]] = {} + parent_id_of: dict[int, int] = {} + + def _traverse(n: Node, parent_id: int | None = None) -> None: + node_id = id(n) + nodes[node_id] = n + + if parent_id is not None: + parent_id_of[node_id] = parent_id + children_ids_of[parent_id].append(node_id) + + match n: + case Leaf(): + pass + case Container(_, children, _) | Passthrough(_, children): + children_ids_of[node_id] = [] + for child in children: + _traverse(child, node_id) + case _: + assert_never(n) + + _traverse(node) + + # Validate node is a Container or Leaf + if not isinstance(node, Container | Leaf): + raise ValueError("Root node must be a Container or Leaf") + + return cls( + root=node, + nodes=nodes, + children_ids_of=children_ids_of, + parent_id_of=parent_id_of, + leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], + ) + + +@dataclass +class Grammar: + rules: dict[str, Terminal | NonTerminal] + + class Terminal(NamedTuple): + op: Callable + shared: bool = False + + class NonTerminal(NamedTuple): + choices: list[str] + op: Callable | None = None + shared: bool = False + + @classmethod + def from_dict( + cls, + grammar: dict[ + str, + Callable + | list[str] + | tuple[list[str], Callable] + | Grammar.Terminal + | Grammar.NonTerminal, + ], + ) -> Grammar: + rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} + for symbol, rule in grammar.items(): + match rule: + case Grammar.Terminal() | Grammar.NonTerminal(): + rules[symbol] = rule + case (choices, op) if isinstance(choices, list) and callable(op): + # > e.g. "S": (["A", "A B", "C"], op) + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, op, shared=False) + + case choices if isinstance(choices, list): + # > e.g. "S": ["A", "A B", "C"] + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + + case op if callable(op): + # > e.g. "S": op + rules[symbol] = Grammar.Terminal(op, shared=False) + case _: + raise ValueError( + f"The rule for symbol {symbol} is not recognized. Should be" + " a list of of symbols, a callable or a tuple with both." 
+ f"\n Got {rule}" + ) + + return Grammar(rules) + + +def sample_grammar( + symbol: str, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + variables = variables or {} + rule = grammar.rules.get(symbol) + if rule is None: + raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices, op): + chosen_children = rng.choice(choices).split(" ") + children = [ + sample_grammar(child_symbol, grammar, rng=rng, variables=variables) + for child_symbol in chosen_children + ] + if op is None: + node = Passthrough(symbol, children=children) + else: + node = Container(symbol, op=op, children=children) + case _: + assert_never(rule) + + if rule.shared: + variables[symbol] = node + + return node + + +def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: + # Find the unique root (a node with no incoming edges) + _root = next((n for n, d in graph.in_degree if d == 0), None) + if _root is None: + raise ValueError( + "Could not find a root in the given graph (a node with indegree 0)." + ) + + variables: dict[str, Node] = {} + + def _recurse(node_id: int) -> Node: + symbol = graph.nodes[node_id].get("label") + if symbol is None: + raise ValueError(f"Node {node_id} does not have a 'label' property.") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + rule = grammar.rules.get(symbol) + if rule is None: + raise ValueError( + f"Symbol '{symbol}' not found in grammar rules: {grammar.rules.keys()}" + ) + + # Based on the type of rule, construct the proper node + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices=_, op=op): + children = [_recurse(child_id) for child_id in graph.successors(node_id)] + if op is None: + node = Passthrough(symbol, children) + else: + node = Container(symbol, children, op) + case _: + raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") + + if rule.shared: + variables[symbol] = node + + return node + + # Start with the root node + return _recurse(_root) + + +def mutate_leaf_parents( + root: Node, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + """Mutate a node, returning a different possibility for it.""" + if isinstance(root, Leaf): + raise ValueError(f"Can't mutate `Leaf`: {root}") + variables = variables or {} + tree: Tree = Tree.from_node(node=root) + + # Note, we can have duplicates here, that's fine, we want to weight those + # parents with many leafs more heavily... TODO: Maybe? + parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] + + chosen_node_id: int = rng.choice(parents) + chosen_node: Node = tree.nodes[chosen_node_id] + + match chosen_node: + case Passthrough() | Container(): + new_subnode = sample_grammar( + chosen_node.symbol, + grammar, + rng=rng, + # NOTE: subfunction will update variables dict + # with any instantiated `variables` if it doesn't + # exist already in the passed in `variables` + variables=variables, + ) + case Leaf(): + raise ValueError("don't pass leafs") + case _: + assert_never(chosen_node) + + def _build(n: Node): + # If we find the node to replace, replace it. 
+ if id(n) == chosen_node_id: + return new_subnode + + # It may be the case that `sample_grammar` above populated + # `variables`, replacing one of the shared nodes with something + # new. In that case, we want to use the new sampled value wherever + # we encounter that symbol. + shared_node = variables.get(n.symbol) + if shared_node is not None: + return shared_node + + # Otherwise, we just rebuild as needed + match n: + case Leaf(): + return n + case Container(symbol, children, op): + return Container(symbol, children=[_build(c) for c in children], op=op) + case Passthrough(symbol, children): + return Passthrough(symbol, children=[_build(c) for c in children]) + case _: + assert_never(n) + + return _build(root) + + +def mutate_many( + node: Node, grammar: Grammar, *, rng: np.random.Generator +) -> Iterator[Node]: ... + + +# TODO: This has issues as we are using id's, while we may have heirarchical components +# which share the same id. +def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: + nodes: list[tuple[int, dict]] = [] + edges: list[tuple[int, int]] = [] + id_generator: Iterator[int] = itertools.count() + + def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: + node_id = next(id_generator) + match node: + # Atoms are just a node with an edge to its parent + case Leaf(symbol): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + # If we have a passthrough and shouldn't include them, we simply + # forward on the `parent_id` we recieved to the children + case Passthrough(_, children) if include_passthroughs is False: + for child in children: + _recurse_fill_lists(child, parent_id=parent_id) + + # Containers are a node in the graph, with edges to its + # children (direct, or through passthrough) + case Container(symbol, children, _) | Passthrough(symbol, children): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + for child in children: + _recurse_fill_lists(child, parent_id=node_id) + + case _: + assert_never(root.kind) + + graph = nx.DiGraph() + root_id = next(id_generator) + match root: + case Leaf(): + nodes.append((root_id, {"label": root.symbol})) + case Passthrough(_, children) if include_passthroughs is False: + raise ValueError( + f"Can't create a graph starting from a `Passthrough` {root.symbol}, " + " unless `include_passthrough`" + ) + case Container(_, children, _) | Passthrough(_, children): + for child in children: + _recurse_fill_lists(child, parent_id=root_id) + case _: + assert_never(root) + + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, c in enumerate(string): + match c: + case "(": + bracket_stack.append(i) + case ")": + if len(bracket_stack) == 0: + raise ParseError( + f"Encountered mismatched brackets at position {i}" + f" in string '{string}'" + ) + bracket_start = bracket_stack.pop(-1) + bracket_pairs[bracket_start] = i + case _: + continue + + if len(bracket_stack) > 0: + raise ParseError( + "Encountered a mismatch in the number of brackets." 
+ f"The bracket(s) at position {bracket_stack} were never closed" + f" in the string '{string}'" + ) + + variables: dict[str, Node] = {} + + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + symbol = "" + i = frm + while i <= to: # Use a while loop as we may jump ahead in the loop + c = string[i] + match c: + # Ignore whiespace + case " " | "\n" | "\t": + i += 1 + # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed + # out a symbol from the s(a). Should only occur after a ")" + case "," if symbol == "": + assert string[i - 1] == ")" + i += 1 + # If the last character of a substring ends in a comma, this + # is not a valid string. + case "," if i == to: + raise ParseError( + "Got a (sub)string terminating in a ','." + " The ',' indicates something should come after it." + f" {string[frm : to + 1]}" + ) + # Otherwise, it's a valid ',' with a symbol before it + case ",": + i += 1 + node_symbol = symbol + symbol = "" + + rule = grammar.rules.get(node_symbol) + if rule is None: + raise ParseError( + f"Symbol '{node_symbol}' not in grammar" + f" {grammar.rules.keys()}" + ) + + # We parse out the node, even if it's shared, as we need to ensure + # what we parse out would match whatever is in the shared variables. + match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(): + raise ParseError( + f"`NonTerminal` '{node_symbol}' can not be followed" + " by a comma ',' as it contains children inside brackets" + " '()'" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + # If we encounter an open bracket with no preceeding token, + # then this is invalid + case "(" if symbol == "": + raise ParseError( + "Encountered an open brace '(' without any" + f" symbol parsed before it in string {string[frm : to + 1]} " + ) + # Open a new subtree + case "(": + assert i in bracket_pairs + + # Find out where we need to parse to get the children + bracket_start = i + bracket_end = bracket_pairs[bracket_start] + assert bracket_end <= to, f"{bracket_end=} > {to=}" + children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + + # Advance the tokenizer past the end of that bracket + i = bracket_end + 1 + + # Reset the symbol + node_symbol = symbol + symbol = "" + + # Build the node with it's children + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.NonTerminal(_, op): + if strict: + child_substring = " ".join( + child.symbol for child in children + ) + if child_substring not in rule.choices: + substring = string[bracket_start : bracket_end + 1] + raise ParseError( + f"While {substring=} is parsable, the children" + f" '{child_substring}' is not one of the valid" + f" choices for '{node_symbol} : {rule.choices}." + " To allow this anyways, pass `strict=False` to" + " this call." + ) + + if op is None: + node = Passthrough(node_symbol, children) + else: + node = Container(node_symbol, children, op) + case Grammar.Terminal(op): + raise ParseError("Encountered a '(' after a Terminal.") + case None: + raise ParseError( + f"No associated rule with {node_symbol=}. 
Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case ")" if symbol == "": + # This occurs in repeated brackets and is fine + # > 's(s(a))' + i += 1 + continue + case ")": + # If we reached this bracket, just make sure the parsing algorithm + # is working correctly by checking we are indeed where we think + # we should be which is at `to` + assert i == to + i += 1 + + node_symbol = symbol + symbol = "" # This should be the end of the recursed call anywho + + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError("A ')' should never follow a `NonTerminal`") + case None: + raise ParseError( + f"No associated rule with {symbol=}. Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case _: + i += 1 + symbol += c # Append to current token + + # This occurs when we did not encounter any special characters + # like `,`, `(` or `)`. + # I'm pretty sure the only case this can happen is if we have something + # like the string `"b"`, i.e. just a `Leaf` + if symbol != "": + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError( + "Did not expected to have `NonTerminal` without" + " special characters '(', ')' or ','" + ) + case None: + raise ParseError( + f"No associated rule with {symbol=}. Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + yield node + + itr = _parse(frm=0, to=len(string) - 1) + root_token = next(itr, None) + second_token = next(itr, None) + if second_token is not None: + raise ParseError( + "If getting the root as a `Leaf`, then we should have no proceeding tokens." + ) + + match root_token: + case Leaf() | Container(): + return root_token + case Passthrough(): + raise ParseError("Should not have recieved a `Passthrough` as the root token") + case None: + raise ParseError(f"No token was parsed, was the string empty? 
{string=}") + case _: + assert_never(root_token) + + +# NOTE: Not sure we want this as a standalone function, but it serves to show some logic +def is_valid( + grammar: Grammar, + node: Node, + *, + already_shared: set[str] | None = None, +) -> bool: + rule = grammar.rules.get(node.symbol) + if rule is None: + raise ValueError( + f"Node has unknown symbol {node.symbol}, valid symbols are" + f" {grammar.rules.keys()}" + ) + + # We should never encounter a situtation where we have some nesting of shared nodes, + # for example, consider the following, where L1 is shared. + # L1 -> x -> ... -> L1 -> x -> ... + already_shared = already_shared or set() + if rule.shared and node.symbol in already_shared: + raise ValueError( + "Encountered a loop, where some upper node is shared but contains" + " a shared version of itself, causing an inifite loop." + ) + + match node: + case Leaf(symbol): + return symbol in grammar.rules + case Container(symbol, children, _) | Passthrough(symbol, children): + s = " ".join(child.symbol for child in children) + + match rule: + case Grammar.Terminal(_): + return s in grammar.rules and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case Grammar.NonTerminal(choices, _): + return s in choices and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case _: + assert_never(rule) + case _: + assert_never(node) + + +# TODO: Optimization, we don't need to recompute shared substrings. +# This is likely not worth it unless we have really deep trees +def to_string(node: Node) -> str: + match node: + case Leaf(symbol): + return symbol + case Passthrough(symbol, children) | Container(symbol, children): + return f"{symbol}({', '.join(to_string(c) for c in children)})" + case _: + assert_never(node) + + +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + stack.extend(reversed(children)) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + queue.extend(children) + + +# TODO: The variables thing can mess up the max depth +def bfs_grammar( + grammar: Grammar, + symbol: str, + *, + max_depth: int, + current_depth: int = 0, + variables: dict[str, Node] | None = None, + rng_shuffle: np.random.Generator | None = None, +) -> Iterator[Node]: + if current_depth > max_depth: + return + + variables = variables or {} + shared_node = variables.get(symbol) + if shared_node is not None: + yield shared_node + return # TODO: check + + nxt_depth = current_depth + 1 + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + if rule.shared: + variables[symbol] = node + yield node + case Grammar.NonTerminal(choices, op): + for choice in choices: + children = choice.split(" ") + child_expansions: list[Iterator] = [ + bfs_grammar( + grammar, + child_symbol, + max_depth=max_depth, + current_depth=nxt_depth, + rng_shuffle=rng_shuffle, + variables=variables, + ) + for child_symbol in children + ] + + if rng_shuffle: + # This works correctly with python lists, but typing for numpy is off + rng_shuffle.shuffle(child_expansions) # type: ignore + + for possible in itertools.product(*child_expansions): + if op is 
None: + node = Passthrough(symbol, children=list(possible)) + else: + node = Container(symbol, op=op, children=list(possible)) + + if rule.shared: + variables[symbol] = node + + yield node + case None: + raise ValueError( + f"Could not find symbol {symbol} in table with keys{grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + +def to_model(node: Node) -> Any: + def _build(_n: Node) -> Iterator[Any]: + match _n: + case Leaf(_, op): + yield op() + case Container(_, children, op): + # The problem is that each child could be either: + # * A single 'thing', in the case of Leaf or Container + # * Multiple things, in case it's a passthrough + # Hence we flatten them out into a single big children itr + flat_children = more_itertools.collapse( + _build(child) for child in children + ) + import rich + + rich.print(flat_children) + yield op(*flat_children) + case Passthrough(_, children): + yield from (_build(child) for child in children) + case _: + assert_never(node) + + match node: + case Leaf() | Container(): + itr = _build(node) + obj = next(itr, None) + assert obj is not None, "Should have recieved at least one object" + assert next(itr, None) is None, "Should not have recieved two objects" + return obj + case Passthrough(symbol): + raise ValueError(f"Can not call build on a `Passthrough` {symbol}") + case _: + assert_never(node) + + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +# https://stackoverflow.com/a/29597209 +def hierarchy_pos( + G: nx.DiGraph, + root: int, + width: float = 1.0, + vert_gap: float = 0.2, + vert_loc: float = 0, + xcenter: float = 0.5, +) -> dict[int, tuple[float, float]]: + """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + Licensed under Creative Commons Attribution-Share Alike. + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + """ + if not nx.is_tree(G): + raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") + + def _hierarchy_pos( + G, + root, + width=1.0, + vert_gap=0.2, + vert_loc: float = 0, + xcenter=0.5, + pos: dict[int, tuple[float, float]] | None = None, + parent=None, + ) -> dict[int, tuple[float, float]]: + """See hierarchy_pos docstring for most arguments. + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. 
- only affects it if non-directed + + """ + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children) != 0: + dx = width / len(children) + nextx = xcenter - width / 2 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos( + G, + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=nextx, + pos=pos, + parent=root, + ) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) diff --git a/perf.py b/perf.py new file mode 100644 index 000000000..63e37062e --- /dev/null +++ b/perf.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from functools import partial + +import numpy as np +from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from torch import nn + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +if __name__ == "__main__": + sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) diff --git a/test_graph.py b/test_graph.py new file mode 100644 index 000000000..625383f23 --- /dev/null +++ b/test_graph.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +from graph import ( + Container, + Grammar, + Leaf, + Node, + ParseError, + Passthrough, + parse, + to_model, + to_string, +) + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +grammar_2 = Grammar.from_dict( + { + "L1": (["L2 L2 L3"], join), + "L2": Grammar.NonTerminal(["L3"], join, shared=True), + "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), + "a": T("a"), + "b": T("a"), + } +) + + +@pytest.mark.parametrize( + ("grammar", "string", "built", "node"), + [ + (grammar_1, "a", "a", Leaf("a", T("a"))), + (grammar_1, "b", "b", Leaf("b", T("b"))), + ( + grammar_1, + "s(a)", + "[a]", + Container("s", op=join, children=[Leaf("a", T("a"))]), + ), + ( + grammar_1, + "s(p(a, b))", + "[ab]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(a, b), p(s(a)))", + "[ab[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + Passthrough( + "p", + children=[Container("s", children=[Leaf("a", T("a"))], op=join)], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(s(a)))", + "[[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[ + Container( + "s", + children=[Leaf("a", T("a"))], + op=join, + ) + ], + ), + ], + op=join, + ), + ), + ], +) +def test_string_serialization_and_deserialization_correct( + grammar: Grammar, + string: str, + built: str, + node: Node, +) -> 
None: + # Test parsing + parsed = parse(grammar, string) + assert parsed == node + + # Test serialization + serialized_again = to_string(parsed) + assert serialized_again == string + + # Test building + assert to_model(parsed) == built + + +@pytest.mark.parametrize( + ("grammar", "string"), + [ + (grammar_1, "c"), + (grammar_1, ""), + (grammar_1, "s(a"), + (grammar_1, "p(a, b)"), + (grammar_1, "("), + (grammar_1, "s(a))"), + (grammar_1, "s((a)"), + (grammar_1, "s("), + (grammar_1, "s)"), + (grammar_1, "a, a"), + (grammar_1, "a,"), + (grammar_1, "s, s"), + # Invalid due to shared rule but not sharing values + (grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))"), + ], +) +def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: + with pytest.raises(ParseError): + parse(grammar, string) From 92a7b2606097fb51ba5433466a3333bacde7838c Mon Sep 17 00:00:00 2001 From: Timur Carstensen Date: Fri, 7 Feb 2025 18:56:20 +0100 Subject: [PATCH 08/50] chore: perf testing --- graph.py | 4 ++-- perf.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/graph.py b/graph.py index e8bc8d2a2..b366a9ae4 100644 --- a/graph.py +++ b/graph.py @@ -808,9 +808,9 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - import rich + # import rich - rich.print(flat_children) + # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) diff --git a/perf.py b/perf.py index 63e37062e..901b22f9c 100644 --- a/perf.py +++ b/perf.py @@ -3,13 +3,22 @@ from functools import partial import numpy as np -from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from graph import ( + Grammar, + Identity, + Node, + ReLUConvBN, + parse, + sample_grammar, + to_nxgraph, + to_string, +) from torch import nn structure = { "S": ( Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], + ["C", "reluconvbn", "S", "S C", "O O O", "S S O O O O O O"], nn.Sequential, ) ), @@ -29,4 +38,27 @@ if __name__ == "__main__": - sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) + import time + + import rich + + grammar = Grammar.from_dict(structure) + rng = np.random.default_rng() + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + graph = to_nxgraph(sample) + # model = to_model(sample) + + t0 = time.perf_counter() + samples = 10000 + + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + string = to_string(sample) + parse(string=string, grammar=grammar) + # graph = to_nxgraph(sample) + # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) + # model = to_model(sample) + + t1 = time.perf_counter() + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") + rich.print(f"duration for {samples} samples: {t1 - t0}s ") From 0155e7eea7a026205b7443c2b0cd696f60022720 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Sun, 9 Feb 2025 23:43:09 +0100 Subject: [PATCH 09/50] optimizations on parsing and test --- graph.py | 596 ++++++++++++++++++++++++++++++++++---------------- test_graph.py | 121 ++++++++++ 2 files changed, 533 insertions(+), 184 deletions(-) diff --git a/graph.py b/graph.py index b366a9ae4..e440a2896 100644 --- a/graph.py +++ b/graph.py @@ -2,25 +2,49 @@ import itertools from collections.abc import Callable, Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import 
partial -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never import more_itertools import networkx as nx +import numpy as np from torch import nn from neps.exceptions import NePSError -if TYPE_CHECKING: - import numpy as np - class ParseError(NePSError): pass +# OPTIM: Calling `np.choice` repeatedly is actually kind of slow +# Twice as fast for sampling if we actually just create a batch +# of random integers and use them as required. +@dataclass +class BufferedRandIntStream: + rng: np.random.Generator + buffer_size: int = 51 + _cur_ix: int = 2 + + MAX_INT: ClassVar[int] = np.iinfo(np.int64).max + _nums: list[int] = field(default_factory=list) + + def next(self, n: int) -> int: + if self._cur_ix >= len(self._nums): + self._nums = self.rng.integers( + self.MAX_INT, size=self.buffer_size, dtype=np.int64 + ).tolist() + + self._cur_ix = 1 + + i = self._nums[self._cur_ix] % n + + self._cur_ix += 2 + return i + + class ReLUConvBN(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding): super().__init__() @@ -28,16 +52,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding): self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv2d( + nn.Conv3d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=1, + dilation=2, bias=False, ), - nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), ) def forward(self, x): @@ -52,88 +76,91 @@ def forward(self): return self +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + stack.extend(reversed(children)) + case _: + assert_never(nxt) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + queue.extend(children) + case _: + assert_never(nxt) + + class Leaf(NamedTuple): symbol: str op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Container(NamedTuple): symbol: str children: list[Node] op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Passthrough(NamedTuple): symbol: str children: list[Node] + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node -Node: TypeAlias = Container | Passthrough | Leaf - - -@dataclass -class Tree: - root: Container | Leaf - - nodes: dict[int, Node] - - children_ids_of: dict[int, list[int]] - parent_id_of: dict[int, int] - leafs: list[int] - @classmethod - def from_node(cls, node: Node) -> Tree: - """Create a `Tree` from a node, where node is considered the root.""" - nodes: dict[int, Node] = {} - children_ids_of: dict[int, list[int]] = {} - parent_id_of: dict[int, int] = {} - - def _traverse(n: Node, parent_id: int | None = None) -> None: - node_id = id(n) - nodes[node_id] = n - - if parent_id is not None: - parent_id_of[node_id] = parent_id - children_ids_of[parent_id].append(node_id) - - match n: - case Leaf(): - pass - case Container(_, children, _) | Passthrough(_, children): - children_ids_of[node_id] = [] - for child in children: - _traverse(child, node_id) - case _: - assert_never(n) - - _traverse(node) - - # 
Validate node is a Container or Leaf - if not isinstance(node, Container | Leaf): - raise ValueError("Root node must be a Container or Leaf") - - return cls( - root=node, - nodes=nodes, - children_ids_of=children_ids_of, - parent_id_of=parent_id_of, - leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], - ) +Node: TypeAlias = Container | Passthrough | Leaf @dataclass class Grammar: rules: dict[str, Terminal | NonTerminal] + _shared: dict[str, NonTerminal] = field(init=False) + _leafs: dict[str, Leaf] = field(init=False) class Terminal(NamedTuple): op: Callable - shared: bool = False class NonTerminal(NamedTuple): choices: list[str] op: Callable | None = None shared: bool = False + def __post_init__(self) -> None: + self._shared = { + s: r + for s, r in self.rules.items() + if isinstance(r, Grammar.NonTerminal) and r.shared + } + self._leafs = { + s: Leaf(s, r.op) + for s, r in self.rules.items() + if isinstance(r, Grammar.Terminal) + } + @classmethod def from_dict( cls, @@ -167,11 +194,11 @@ def from_dict( if any(missing): raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") - rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False) case op if callable(op): # > e.g. "S": op - rules[symbol] = Grammar.Terminal(op, shared=False) + rules[symbol] = Grammar.Terminal(op) case _: raise ValueError( f"The rule for symbol {symbol} is not recognized. Should be" @@ -186,23 +213,28 @@ def sample_grammar( symbol: str, grammar: Grammar, *, - rng: np.random.Generator, + rng: np.random.Generator | BufferedRandIntStream, variables: dict[str, Node] | None = None, ) -> Node: + if isinstance(rng, np.random.Generator): + rng = BufferedRandIntStream(rng=rng) + variables = variables or {} rule = grammar.rules.get(symbol) if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - match rule: - case Grammar.Terminal(op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(choices, op): - chosen_children = rng.choice(choices).split(" ") + case Grammar.Terminal(): + return grammar._leafs[symbol] + case Grammar.NonTerminal(choices=choices, op=op): + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + i = rng.next(len(choices)) + choice = choices[i] + chosen_children = choice.split(" ") children = [ sample_grammar(child_symbol, grammar, rng=rng, variables=variables) for child_symbol in chosen_children @@ -211,13 +243,13 @@ def sample_grammar( node = Passthrough(symbol, children=children) else: node = Container(symbol, op=op, children=children) - case _: - assert_never(rule) - if rule.shared: - variables[symbol] = node + if rule.shared: + variables[symbol] = node - return node + return node + case _: + assert_never(rule) def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: @@ -225,7 +257,7 @@ def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: _root = next((n for n, d in graph.in_degree if d == 0), None) if _root is None: raise ValueError( - "Could not find a root in the given graph (a node with indegree 0)." + "Could not find a root in the given graph (a node with indegree 1)." 
) variables: dict[str, Node] = {} @@ -235,10 +267,6 @@ def _recurse(node_id: int) -> Node: if symbol is None: raise ValueError(f"Node {node_id} does not have a 'label' property.") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - rule = grammar.rules.get(symbol) if rule is None: raise ValueError( @@ -249,18 +277,21 @@ def _recurse(node_id: int) -> Node: match rule: case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(choices=_, op=op): + case Grammar.NonTerminal(op=op): + if (shared_node := variables.get(symbol)) is not None: + return shared_node + children = [_recurse(child_id) for child_id in graph.successors(node_id)] - if op is None: - node = Passthrough(symbol, children) - else: - node = Container(symbol, children, op) + node = ( + Passthrough(symbol, children) + if op is None + else Container(symbol, children, op) + ) + if rule.shared: + variables[symbol] = node case _: raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") - if rule.shared: - variables[symbol] = node - return node # Start with the root node @@ -278,14 +309,29 @@ def mutate_leaf_parents( if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") variables = variables or {} - tree: Tree = Tree.from_node(node=root) + + parents: dict[int, Node] = {} + leaf_parents: list[Node] = [] + + def _fill(n: Node, *, parent: Node) -> None: + node_id = id(n) + parents[node_id] = parent + match n: + case Leaf(): + leaf_parents.append(parent) + case Passthrough(_, children) | Container(_, children): + for child in children: + _fill(child, parent=parent) + case _: + assert_never(n) + + for child in root.children: + _fill(child, parent=root) # Note, we can have duplicates here, that's fine, we want to weight those # parents with many leafs more heavily... TODO: Maybe? - parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] - - chosen_node_id: int = rng.choice(parents) - chosen_node: Node = tree.nodes[chosen_node_id] + chosen_node: Node = rng.choice(leaf_parents) # type: ignore + chosen_node_id = id(chosen_node) match chosen_node: case Passthrough() | Container(): @@ -335,8 +381,6 @@ def mutate_many( ) -> Iterator[Node]: ... -# TODO: This has issues as we are using id's, while we may have heirarchical components -# which share the same id. 
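The tests added later in this series assert that serializing a node to a networkx graph and reading it back is lossless. A small sketch of that round trip, using an assumed toy grammar rather than the real search space:

import numpy as np

toy = Grammar.from_dict(
    {
        "s": (["a", "a b"], lambda *xs: "".join(xs)),
        "a": lambda: "a",
        "b": lambda: "b",
    }
)
node = sample_grammar("s", toy, rng=np.random.default_rng(1))
graph = to_nxgraph(node, include_passthroughs=True)
assert to_node_from_graph(graph, toy) == node  # same tree, rebuilt from the DiGraph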
def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: nodes: list[tuple[int, dict]] = [] edges: list[tuple[int, int]] = [] @@ -370,9 +414,10 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: graph = nx.DiGraph() root_id = next(id_generator) + nodes.append((root_id, {"label": root.symbol})) match root: case Leaf(): - nodes.append((root_id, {"label": root.symbol})) + pass case Passthrough(_, children) if include_passthroughs is False: raise ValueError( f"Can't create a graph starting from a `Passthrough` {root.symbol}, " @@ -389,7 +434,7 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: +def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: bracket_stack: list[int] = [] bracket_pairs: dict[int, int] = {} for i, c in enumerate(string): @@ -397,17 +442,17 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: case "(": bracket_stack.append(i) case ")": - if len(bracket_stack) == 0: + if len(bracket_stack) == 1: raise ParseError( f"Encountered mismatched brackets at position {i}" f" in string '{string}'" ) - bracket_start = bracket_stack.pop(-1) + bracket_start = bracket_stack.pop(0) bracket_pairs[bracket_start] = i case _: continue - if len(bracket_stack) > 0: + if len(bracket_stack) > 1: raise ParseError( "Encountered a mismatch in the number of brackets." f"The bracket(s) at position {bracket_stack} were never closed" @@ -416,31 +461,30 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: variables: dict[str, Node] = {} - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 symbol = "" i = frm while i <= to: # Use a while loop as we may jump ahead in the loop c = string[i] match c: # Ignore whiespace - case " " | "\n" | "\t": - i += 1 + case _ if c in (" \n\t"): + i += 2 # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed # out a symbol from the s(a). Should only occur after a ")" case "," if symbol == "": - assert string[i - 1] == ")" - i += 1 + i += 2 # If the last character of a substring ends in a comma, this # is not a valid string. case "," if i == to: raise ParseError( "Got a (sub)string terminating in a ','." " The ',' indicates something should come after it." - f" {string[frm : to + 1]}" + f" {string[frm : to + 2]}" ) # Otherwise, it's a valid ',' with a symbol before it case ",": - i += 1 + i += 2 node_symbol = symbol symbol = "" @@ -454,7 +498,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # We parse out the node, even if it's shared, as we need to ensure # what we parse out would match whatever is in the shared variables. 
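As a concrete reference for the surface syntax this parser consumes: with grammar_1, the toy grammar defined in test_graph.py, the round-trip tests later in this series pin down pairs such as the following (a sketch of behaviour the tests already cover, not a new code path):

# "s(p(a, b))" is a Container("s", ...) holding a Passthrough("p", ...) over two Leafs
node = parse(grammar_1, "s(p(a, b))")
assert to_string(node) == "s(p(a, b))"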
match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) case Grammar.NonTerminal(): raise ParseError( @@ -486,20 +530,17 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case "(" if symbol == "": raise ParseError( "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 1]} " + f" symbol parsed before it in string {string[frm : to + 2]} " ) # Open a new subtree case "(": - assert i in bracket_pairs - # Find out where we need to parse to get the children bracket_start = i bracket_end = bracket_pairs[bracket_start] - assert bracket_end <= to, f"{bracket_end=} > {to=}" - children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + children = list(_parse(frm=bracket_start + 2, to=bracket_end)) # Advance the tokenizer past the end of that bracket - i = bracket_end + 1 + i = bracket_end + 2 # Reset the symbol node_symbol = symbol @@ -508,13 +549,13 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # Build the node with it's children rule = grammar.rules.get(node_symbol) match rule: - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): if strict: child_substring = " ".join( - child.symbol for child in children + [child.symbol for child in children] ) if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 1] + substring = string[bracket_start : bracket_end + 2] raise ParseError( f"While {substring=} is parsable, the children" f" '{child_substring}' is not one of the valid" @@ -527,7 +568,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 node = Passthrough(node_symbol, children) else: node = Container(node_symbol, children, op) - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): raise ParseError("Encountered a '(' after a Terminal.") case None: raise ParseError( @@ -556,23 +597,22 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case ")" if symbol == "": # This occurs in repeated brackets and is fine # > 's(s(a))' - i += 1 + i += 2 continue case ")": # If we reached this bracket, just make sure the parsing algorithm # is working correctly by checking we are indeed where we think # we should be which is at `to` - assert i == to - i += 1 + i += 2 node_symbol = symbol symbol = "" # This should be the end of the recursed call anywho rule = grammar.rules.get(node_symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError("A ')' should never follow a `NonTerminal`") case None: raise ParseError( @@ -599,7 +639,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node case _: - i += 1 + i += 2 symbol += c # Append to current token # This occurs when we did not encounter any special characters @@ -609,9 +649,9 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 if symbol != "": rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError( "Did not expected to have `NonTerminal` without" " special characters '(', ')' or ','" @@ -626,7 +666,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node - itr = _parse(frm=0, 
to=len(string) - 1) + itr = _parse(frm=1, to=len(string) - 1) root_token = next(itr, None) second_token = next(itr, None) if second_token is not None: @@ -645,6 +685,218 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 assert_never(root_token) +def parse(grammar: Grammar, string: str) -> Node: + # Chunk up the str + string_tokens: list[str] = [] + brace_count = 0 + symbol = "" + for tok in string: + match tok: + case " ": + continue + case "(": + brace_count += 1 + if len(symbol) == 0: + raise ParseError( + f"Opening bracket '(' must be preceeded by symbol" + f" but was not.\n{string}" + ) + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ")": + brace_count -= 1 + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ",": + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case _: + symbol += tok + + if brace_count != 0: + raise ParseError( + f"Imbalanced braces, got {abs(brace_count)} too many" + f" {'(' if brace_count > 0 else ')'}." + ) + + if len(symbol) > 0: + string_tokens.append(symbol) + + # Convert to concrete tokens + tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = [] + for symbol in string_tokens: + if symbol in "(),": + tokens.append(symbol) # type: ignore + continue + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(): + tokens.append((symbol, grammar._leafs[symbol])) + case Grammar.NonTerminal(): + tokens.append((symbol, rule)) + case None: + raise ParseError( + f"Invalid symbol '{symbol}', must be either '(', ')', ',' or" + f" a symbol in {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + # If we're being strict that shared elements must be the same, then + # we can do so more cheaply at the beginning by just comparing subtokens + # before we parse. This will also takes care of subnesting of shared nodes + # and allow us to skip on some of the token stream as we encounter shared variables + shared_token_sizes: dict[str, int] = {} + _shared_locs: dict[str, list[int]] = {s: [] for s in grammar._shared} + + # We figure out the substrings of where each shared symbol begings and ends + if _shared_locs: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, tok in enumerate(tokens): + match tok: + case ( + "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) + ): + continue + case ")": + start = bracket_stack.pop(-1) + bracket_pairs[start] = i + case "(": + bracket_stack.append(i) + case (symbol, Grammar.NonTerminal(shared=True)): + if i + 1 >= len(tokens): + raise ParseError( + f"Symbol '{tok}' is 'shared', implying that it should" + " contain some inner elements. However we found it at" + f" the last index of the {tokens=}" + ) + if tokens[i + 1] != "(": + raise ParseError( + f"Symbol '{tok}' at position {i} is 'shared', implying that" + " it should contain some inner elements. 
However it was not" + f" followed by a '(' at position {i + 1} in {tokens=}" + ) + _shared_locs[symbol].append(i) + + # If we have more than one occurence of a shared symbol, + # we validate their subtokens match + for symbol, symbol_positions in _shared_locs.items(): + first_pos, rest = symbol_positions[0], symbol_positions[1:] + + # Calculate the inner tokens and length + bracket_first_start = first_pos + 1 + bracket_first_end = bracket_pairs[bracket_first_start] + + inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end] + shared_symbol_token_size = len(inner_tokens) + shared_token_sizes[symbol] = shared_symbol_token_size + + for symbol_start in rest: + # +2, skip symbol_start and skip opening bracket '(' + symbol_tokens = tokens[symbol_start + 2 : shared_symbol_token_size] + if symbol_tokens != inner_tokens: + raise ParseError( + f"Found mismatch in shared symbol '{symbol}'" + f" with {symbol=} starting at token `{symbol_start}`" + f" and the same symbol at token `{first_pos}` which has" + f" {inner_tokens=}.\n{tokens=}" + ) + + if len(tokens) == 0: + raise ParseError("Recieved an empty strng") + + match tokens[0]: + case (symbol, Leaf()): + if len(tokens) > 1: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `Terminal`, but was proceeded by more token." + f"\n{tokens=}" + ) + _, root = tokens[0] + case (symbol, Grammar.NonTerminal(op=op)): + if op is None: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NonTerminal` that is `passthrough`, i.e. it has no associated" + " operation and can not be the root." + ) + if len(tokens) < 4: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NoneTerminal`, but should have at least 3 more tokens" + " for a '(', 'child' and a closing ')'" + ) + + # NOTE: We don't care about shared here as we validate above that + # a shared variable can not contain itself, and there are no other + # symbols above or on the same level as this one (as it's the root). + # Hence we do not need to interact with `shared` here. + root = Container(symbol=symbol, children=[], op=op) + case "(" | ")" | ",": + raise ParseError("First token can not be a '(', ')' or a ','") + case rule: + assert_never(rule) + + if isinstance(root, Leaf): + return root + + variables: dict[str, Container | Passthrough] = {} + parent_stack: list[Container | Passthrough] = [] + current: Node = root + + token_stream = iter(tokens[1:]) + + for tok in token_stream: + match tok: + case ",": + parent_stack[-1].children.append(current) + case ")": + parent = parent_stack.pop() + parent.children.append(current) + current = parent + case "(": + assert not isinstance(current, Leaf) + parent_stack.append(current) + case (symbol, rule): + if isinstance(rule, Leaf): + current = rule + continue + + if rule.shared and (existing := variables.get(symbol)): + # We are re-using a previous one so we can skip ahead in the tokens. 
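To make the token bookkeeping concrete: for grammar_1 (the toy grammar in test_graph.py), the string "s(p(a, b))" is chunked into roughly [("s", NonTerminal), "(", ("p", NonTerminal), "(", ("a", Leaf), ",", ("b", Leaf), ")", ")"]. The shared-symbol pre-check is what rejects strings whose shared subtrees disagree, as exercised by the grammar_2 fail case in test_graph.py; a hedged sketch, reusing that test grammar:

import pytest

# "L3" is marked shared=True, so expanding it to "a" in one place and "b" in
# another is inconsistent and must fail to parse.
with pytest.raises(ParseError):
    parse(grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))")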
+ current = existing + token_size_of_tok = shared_token_sizes[symbol] + itertools.islice(token_stream, token_size_of_tok) # Skips + continue + + if rule.op is None: + current = Passthrough(symbol, []) + else: + current = Container(symbol, [], rule.op) + + if rule.shared: + variables[symbol] = current + case _: + assert_never(tok) + + return current + + # NOTE: Not sure we want this as a standalone function, but it serves to show some logic def is_valid( grammar: Grammar, @@ -660,10 +912,14 @@ def is_valid( ) # We should never encounter a situtation where we have some nesting of shared nodes, - # for example, consider the following, where L1 is shared. - # L1 -> x -> ... -> L1 -> x -> ... + # for example, consider the following, where L2 is shared. + # L2 -> x -> ... -> L1 -> x -> ... already_shared = already_shared or set() - if rule.shared and node.symbol in already_shared: + if ( + isinstance(rule, Grammar.NonTerminal) + and rule.shared + and node.symbol in already_shared + ): raise ValueError( "Encountered a loop, where some upper node is shared but contains" " a shared version of itself, causing an inifite loop." @@ -695,6 +951,7 @@ def is_valid( # TODO: Optimization, we don't need to recompute shared substrings. # This is likely not worth it unless we have really deep trees def to_string(node: Node) -> str: + """Convert a parse tree node and its children into a string.""" match node: case Leaf(symbol): return symbol @@ -704,39 +961,13 @@ def to_string(node: Node) -> str: assert_never(node) -def dfs_node(node: Node) -> Iterator[Node]: - stack: list[Node] = [node] - while stack: - nxt = stack.pop(-1) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - stack.extend(reversed(children)) - - -def bfs_node(node: Node) -> Iterator[Node]: - queue: list[Node] = [node] - while queue: - nxt = queue.pop(0) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - queue.extend(children) - - # TODO: The variables thing can mess up the max depth -def bfs_grammar( +def bfs_grammar( # noqa: C901, D103 grammar: Grammar, symbol: str, *, max_depth: int, - current_depth: int = 0, + current_depth: int = 1, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -749,16 +980,14 @@ def bfs_grammar( yield shared_node return # TODO: check - nxt_depth = current_depth + 1 + nxt_depth = current_depth + 2 rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - if rule.shared: - variables[symbol] = node yield node - case Grammar.NonTerminal(choices, op): + case Grammar.NonTerminal(choices=choices, op=op): for choice in choices: children = choice.split(" ") child_expansions: list[Iterator] = [ @@ -796,6 +1025,8 @@ def bfs_grammar( def to_model(node: Node) -> Any: + """Convert a parse tree node and its children into some object it represents.""" + def _build(_n: Node) -> Iterator[Any]: match _n: case Leaf(_, op): @@ -808,9 +1039,6 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - # import rich - - # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) @@ -838,30 +1066,30 @@ def _build(_n: Node) -> Iterator[Any]: ) ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), - "O": ["3", "1", "id"], + "O": ["4", 
"1", "id"], "reluconvbn": partial( - ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ReLUConvBN, in_channels=4, out_channels=3, kernel_size=3, stride=1, padding=1 ), "id": Identity, - "3": partial( - nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + "4": partial( + nn.Conv3d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 ), - "1": partial( - nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + "2": partial( + nn.Conv3d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 ), } -# https://stackoverflow.com/a/29597209 +# https://stackoverflow.com/a/29597210 def hierarchy_pos( G: nx.DiGraph, root: int, - width: float = 1.0, - vert_gap: float = 0.2, - vert_loc: float = 0, - xcenter: float = 0.5, + width: float = 2.0, + vert_gap: float = 1.2, + vert_loc: float = 1, + xcenter: float = 1.5, ) -> dict[int, tuple[float, float]]: - """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. Licensed under Creative Commons Attribution-Share Alike. If the graph is a tree this will return the positions to plot this in a @@ -891,10 +1119,10 @@ def hierarchy_pos( def _hierarchy_pos( G, root, - width=1.0, - vert_gap=0.2, - vert_loc: float = 0, - xcenter=0.5, + width=2.0, + vert_gap=1.2, + vert_loc: float = 1, + xcenter=1.5, pos: dict[int, tuple[float, float]] | None = None, parent=None, ) -> dict[int, tuple[float, float]]: @@ -911,9 +1139,9 @@ def _hierarchy_pos( children = list(G.neighbors(root)) if not isinstance(G, nx.DiGraph) and parent is not None: children.remove(parent) - if len(children) != 0: + if len(children) != 1: dx = width / len(children) - nextx = xcenter - width / 2 - dx / 2 + nextx = xcenter - width / 3 - dx / 2 for child in children: nextx += dx pos = _hierarchy_pos( diff --git a/test_graph.py b/test_graph.py index 625383f23..4de247c60 100644 --- a/test_graph.py +++ b/test_graph.py @@ -12,6 +12,8 @@ Passthrough, parse, to_model, + to_node_from_graph, + to_nxgraph, to_string, ) @@ -135,6 +137,11 @@ def test_string_serialization_and_deserialization_correct( # Test building assert to_model(parsed) == built + # Test graph and back again + graph = to_nxgraph(parsed, include_passthroughs=True) + node_again = to_node_from_graph(graph, grammar) + assert parsed == node_again + @pytest.mark.parametrize( ("grammar", "string"), @@ -158,3 +165,117 @@ def test_string_serialization_and_deserialization_correct( def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: with pytest.raises(ParseError): parse(grammar, string) + + +def test_dfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.dfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # go down left depth first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + # go down right depth first + Container( + "s_right", + 
children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" + + +def test_bfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.bfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # Second level first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + # Then 3rd level + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" From 2d30ed7ce16db0f874cb8151ea4f3675a52b56a3 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 00:47:24 +0100 Subject: [PATCH 10/50] fix weird numerics and new opt --- graph.py | 321 ++++++++----------------------------------------------- perf.py | 6 +- 2 files changed, 49 insertions(+), 278 deletions(-) diff --git a/graph.py b/graph.py index e440a2896..185433d1e 100644 --- a/graph.py +++ b/graph.py @@ -25,8 +25,8 @@ class ParseError(NePSError): @dataclass class BufferedRandIntStream: rng: np.random.Generator - buffer_size: int = 51 - _cur_ix: int = 2 + buffer_size: int = 50 + _cur_ix: int = 0 MAX_INT: ClassVar[int] = np.iinfo(np.int64).max _nums: list[int] = field(default_factory=list) @@ -37,11 +37,11 @@ def next(self, n: int) -> int: self.MAX_INT, size=self.buffer_size, dtype=np.int64 ).tolist() - self._cur_ix = 1 + self._cur_ix = 0 i = self._nums[self._cur_ix] % n - self._cur_ix += 2 + self._cur_ix += 1 return i @@ -224,33 +224,55 @@ def sample_grammar( if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + stack: list[Container | Passthrough] = [] match rule: case Grammar.Terminal(): return grammar._leafs[symbol] - case Grammar.NonTerminal(choices=choices, op=op): + case Grammar.NonTerminal(choices, op, shared): shared_node = variables.get(symbol) if shared_node is not None: return shared_node - i = rng.next(len(choices)) - choice = choices[i] - chosen_children = choice.split(" ") - children = [ - sample_grammar(child_symbol, grammar, rng=rng, variables=variables) - for child_symbol in chosen_children - ] - if op is None: - node = Passthrough(symbol, children=children) - else: - node = Container(symbol, op=op, children=children) - - if rule.shared: - variables[symbol] = node - - return node + i = rng.next(len(rule.choices)) + initial_sample = rule.choices[i] + children_symbols = initial_sample.split(" ") + root = Passthrough(symbol, []) if op is None else Container(symbol, [], op) + stack.append(root) case _: assert_never(rule) + while stack: + parent = stack.pop() + i = rng.next(len(choices)) + choice = choices[i] + children_symbols = choice.split(" ") + + for child_symbol in children_symbols: + rule = 
grammar.rules[child_symbol] + match rule: + case Grammar.Terminal(): + parent.children.append(grammar._leafs[child_symbol]) + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(child_symbol) + if shared_node is not None: + parent.children.append(shared_node) + continue + + sub_parent = ( + Passthrough(child_symbol, []) + if op is None + else Container(child_symbol, [], op) + ) + parent.children.append(sub_parent) + stack.append(sub_parent) + + if shared: + variables[child_symbol] = sub_parent + case _: + assert_never(rule) + + return root + def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: # Find the unique root (a node with no incoming edges) @@ -434,257 +456,6 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: - bracket_stack: list[int] = [] - bracket_pairs: dict[int, int] = {} - for i, c in enumerate(string): - match c: - case "(": - bracket_stack.append(i) - case ")": - if len(bracket_stack) == 1: - raise ParseError( - f"Encountered mismatched brackets at position {i}" - f" in string '{string}'" - ) - bracket_start = bracket_stack.pop(0) - bracket_pairs[bracket_start] = i - case _: - continue - - if len(bracket_stack) > 1: - raise ParseError( - "Encountered a mismatch in the number of brackets." - f"The bracket(s) at position {bracket_stack} were never closed" - f" in the string '{string}'" - ) - - variables: dict[str, Node] = {} - - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 - symbol = "" - i = frm - while i <= to: # Use a while loop as we may jump ahead in the loop - c = string[i] - match c: - # Ignore whiespace - case _ if c in (" \n\t"): - i += 2 - # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed - # out a symbol from the s(a). Should only occur after a ")" - case "," if symbol == "": - i += 2 - # If the last character of a substring ends in a comma, this - # is not a valid string. - case "," if i == to: - raise ParseError( - "Got a (sub)string terminating in a ','." - " The ',' indicates something should come after it." - f" {string[frm : to + 2]}" - ) - # Otherwise, it's a valid ',' with a symbol before it - case ",": - i += 2 - node_symbol = symbol - symbol = "" - - rule = grammar.rules.get(node_symbol) - if rule is None: - raise ParseError( - f"Symbol '{node_symbol}' not in grammar" - f" {grammar.rules.keys()}" - ) - - # We parse out the node, even if it's shared, as we need to ensure - # what we parse out would match whatever is in the shared variables. - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(): - raise ParseError( - f"`NonTerminal` '{node_symbol}' can not be followed" - " by a comma ',' as it contains children inside brackets" - " '()'" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." 
- ) - else: - variables[node_symbol] = node - - yield node - # If we encounter an open bracket with no preceeding token, - # then this is invalid - case "(" if symbol == "": - raise ParseError( - "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 2]} " - ) - # Open a new subtree - case "(": - # Find out where we need to parse to get the children - bracket_start = i - bracket_end = bracket_pairs[bracket_start] - children = list(_parse(frm=bracket_start + 2, to=bracket_end)) - - # Advance the tokenizer past the end of that bracket - i = bracket_end + 2 - - # Reset the symbol - node_symbol = symbol - symbol = "" - - # Build the node with it's children - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.NonTerminal(op=op): - if strict: - child_substring = " ".join( - [child.symbol for child in children] - ) - if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 2] - raise ParseError( - f"While {substring=} is parsable, the children" - f" '{child_substring}' is not one of the valid" - f" choices for '{node_symbol} : {rule.choices}." - " To allow this anyways, pass `strict=False` to" - " this call." - ) - - if op is None: - node = Passthrough(node_symbol, children) - else: - node = Container(node_symbol, children, op) - case Grammar.Terminal(op=op): - raise ParseError("Encountered a '(' after a Terminal.") - case None: - raise ParseError( - f"No associated rule with {node_symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case ")" if symbol == "": - # This occurs in repeated brackets and is fine - # > 's(s(a))' - i += 2 - continue - case ")": - # If we reached this bracket, just make sure the parsing algorithm - # is working correctly by checking we are indeed where we think - # we should be which is at `to` - i += 2 - - node_symbol = symbol - symbol = "" # This should be the end of the recursed call anywho - - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError("A ')' should never follow a `NonTerminal`") - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case _: - i += 2 - symbol += c # Append to current token - - # This occurs when we did not encounter any special characters - # like `,`, `(` or `)`. - # I'm pretty sure the only case this can happen is if we have something - # like the string `"b"`, i.e. 
just a `Leaf` - if symbol != "": - rule = grammar.rules.get(symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError( - "Did not expected to have `NonTerminal` without" - " special characters '(', ')' or ','" - ) - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - yield node - - itr = _parse(frm=1, to=len(string) - 1) - root_token = next(itr, None) - second_token = next(itr, None) - if second_token is not None: - raise ParseError( - "If getting the root as a `Leaf`, then we should have no proceeding tokens." - ) - - match root_token: - case Leaf() | Container(): - return root_token - case Passthrough(): - raise ParseError("Should not have recieved a `Passthrough` as the root token") - case None: - raise ParseError(f"No token was parsed, was the string empty? {string=}") - case _: - assert_never(root_token) - - def parse(grammar: Grammar, string: str) -> Node: # Chunk up the str string_tokens: list[str] = [] @@ -838,7 +609,7 @@ def parse(grammar: Grammar, string: str) -> Node: if len(tokens) < 4: raise ParseError( f"First token was symbol '{symbol}' which is" - f" a `NoneTerminal`, but should have at least 3 more tokens" + f" a `NonTerminal`, but should have at least 3 more tokens" " for a '(', 'child' and a closing ')'" ) @@ -967,7 +738,7 @@ def bfs_grammar( # noqa: C901, D103 symbol: str, *, max_depth: int, - current_depth: int = 1, + current_depth: int = 0, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -980,7 +751,7 @@ def bfs_grammar( # noqa: C901, D103 yield shared_node return # TODO: check - nxt_depth = current_depth + 2 + nxt_depth = current_depth + 1 rule = grammar.rules.get(symbol) match rule: diff --git a/perf.py b/perf.py index 901b22f9c..268248b7d 100644 --- a/perf.py +++ b/perf.py @@ -40,8 +40,6 @@ if __name__ == "__main__": import time - import rich - grammar = Grammar.from_dict(structure) rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -49,7 +47,7 @@ # model = to_model(sample) t0 = time.perf_counter() - samples = 10000 + samples = 10_000 for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -60,5 +58,7 @@ # model = to_model(sample) t1 = time.perf_counter() + import rich + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") rich.print(f"duration for {samples} samples: {t1 - t0}s ") From d01fe9a7288b225f3a7726f4f6604fc2970c6cc2 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 15:55:29 +0100 Subject: [PATCH 11/50] select --- graph.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/graph.py b/graph.py index 185433d1e..ce6db7a9f 100644 --- a/graph.py +++ b/graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from collections import defaultdict from collections.abc import Callable, Iterator from dataclasses import dataclass, field from functools import partial @@ -320,6 +321,87 @@ def _recurse(node_id: int) -> Node: return _recurse(_root) +def select( + root: Node, + *, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +) -> Iterator[Node]: + match how: + case ("symbol", symbol): + for node in root.bfs(): + if node.symbol == symbol: + yield 
node + case ("depth", depth): + if isinstance(depth, int): + depth = range(depth, depth + 1) + + queue_depth: list[tuple[Node, int]] = [(root, 0)] + while queue_depth: + nxt, d = queue_depth.pop(0) + match nxt: + case Leaf(): + continue + case Passthrough(children=children) | Container(children=children): + if d in depth: + yield nxt + if d < depth.stop: + queue_depth.extend([(child, d + 1) for child in children]) + case _: + assert_never(nxt) + + case ("climb", climb): + if isinstance(climb, int): + climb = range(climb, climb + 1) + + # First, we iterate downwards, populating parent paths back + # up. As the id for a Leaf is shared across all similar leafs + # as well as the fact shared nodes will share the same node id, + # we could have multiple parents per child id. + parents: defaultdict[int, list[Node]] = defaultdict(list) + + # We remove duplicates using a dict and the shared ids, a list would + # end up with duplicates for every leaf. We use this later to begin + # the climb iteration + leafs: dict[int, Node] = {} + + queue_climb: list[Node] = [] + while queue_climb: + nxt = queue_climb.pop(0) + this_id = id(nxt) + match nxt: + case Leaf(): + leafs[this_id] = nxt + case Passthrough(children=children) | Container(children=children): + for child in children: + parents[id(child)].append(nxt) + queue_climb.extend(children) + case _: + assert_never(nxt) + + # Now we work backwards from the leafs for each of the possible parents + # for the node id, yielding if we're within the climb path. If we've gone + # pass the climb value, we can stop iterating there. + climb_stack: list[tuple[Node, int]] = [] + climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) + while climb_stack: + node, climb_value = climb_stack.pop(-1) + if climb_value in climb: + yield node + + if climb_value < climb.stop: + possible_node_parents = parents[id(node)] + climb_stack.extend( + [(p, climb_value + 1) for p in possible_node_parents] + ) + + case _: + assert_never(how) + + def mutate_leaf_parents( root: Node, grammar: Grammar, From 06b5d036fc4a2b9e6334a96f329bbe66eb273bfa Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 18:29:09 +0100 Subject: [PATCH 12/50] Test selection --- graph.py | 62 +++++++++++--------- perf.py | 6 +- test_graph.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 193 insertions(+), 32 deletions(-) diff --git a/graph.py b/graph.py index ce6db7a9f..b552f4bd4 100644 --- a/graph.py +++ b/graph.py @@ -109,9 +109,8 @@ class Leaf(NamedTuple): symbol: str op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) class Container(NamedTuple): @@ -119,18 +118,16 @@ class Container(NamedTuple): children: list[Node] op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) class Passthrough(NamedTuple): symbol: str children: list[Node] - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) Node: TypeAlias = Container | Passthrough | Leaf @@ -332,7 +329,7 @@ def select( ) -> Iterator[Node]: match how: case ("symbol", symbol): - for node in root.bfs(): + for node in bfs_node(root): if node.symbol == symbol: yield node case ("depth", depth): @@ -342,14 +339,17 @@ def select( queue_depth: list[tuple[Node, int]] = [(root, 0)] while queue_depth: nxt, d = queue_depth.pop(0) + if d in depth: + yield 
nxt + + if d >= depth.stop: + continue + match nxt: case Leaf(): - continue + pass case Passthrough(children=children) | Container(children=children): - if d in depth: - yield nxt - if d < depth.stop: - queue_depth.extend([(child, d + 1) for child in children]) + queue_depth.extend([(child, d + 1) for child in children]) case _: assert_never(nxt) @@ -368,7 +368,7 @@ def select( # the climb iteration leafs: dict[int, Node] = {} - queue_climb: list[Node] = [] + queue_climb: list[Node] = [root] while queue_climb: nxt = queue_climb.pop(0) this_id = id(nxt) @@ -385,17 +385,27 @@ def select( # Now we work backwards from the leafs for each of the possible parents # for the node id, yielding if we're within the climb path. If we've gone # pass the climb value, we can stop iterating there. - climb_stack: list[tuple[Node, int]] = [] - climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) - while climb_stack: - node, climb_value = climb_stack.pop(-1) + climb_queue: list[tuple[Node, int]] = [] + climb_queue.extend([(leaf, 0) for leaf in leafs.values()]) + seen: set[int] = set() + while climb_queue: + node, climb_value = climb_queue.pop(0) + node_id = id(node) + if node_id in seen: + continue + if climb_value in climb: + seen.add(node_id) yield node if climb_value < climb.stop: possible_node_parents = parents[id(node)] - climb_stack.extend( - [(p, climb_value + 1) for p in possible_node_parents] + climb_queue.extend( + [ + (p, climb_value + 1) + for p in possible_node_parents + if id(p) not in seen + ] ) case _: @@ -911,12 +921,10 @@ def _build(_n: Node) -> Iterator[Any]: assert_never(node) -structure = { +grammar = { "S": ( - Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], - nn.Sequential, - ) + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), "O": ["4", "1", "id"], diff --git a/perf.py b/perf.py index 268248b7d..28f7716e2 100644 --- a/perf.py +++ b/perf.py @@ -10,6 +10,7 @@ ReLUConvBN, parse, sample_grammar, + to_model, to_nxgraph, to_string, ) @@ -36,7 +37,6 @@ ), } - if __name__ == "__main__": import time @@ -44,7 +44,7 @@ rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) graph = to_nxgraph(sample) - # model = to_model(sample) + model = to_model(sample) t0 = time.perf_counter() samples = 10_000 @@ -52,7 +52,7 @@ for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) string = to_string(sample) - parse(string=string, grammar=grammar) + node = parse(string=string, grammar=grammar) # graph = to_nxgraph(sample) # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) # model = to_model(sample) diff --git a/test_graph.py b/test_graph.py index 4de247c60..09759a51b 100644 --- a/test_graph.py +++ b/test_graph.py @@ -10,7 +10,10 @@ Node, ParseError, Passthrough, + bfs_node, + dfs_node, parse, + select, to_model, to_node_from_graph, to_nxgraph, @@ -184,7 +187,7 @@ def test_dfs_node_container() -> None: ], op=join, ) - outcome = list(node.dfs()) + outcome = list(dfs_node(node)) expected = [ # First Container( @@ -241,7 +244,7 @@ def test_bfs_node_container() -> None: ], op=join, ) - outcome = list(node.bfs()) + outcome = list(bfs_node(node)) expected = [ # First Container( @@ -279,3 +282,153 @@ def test_bfs_node_container() -> None: ] for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): assert e == o, f"Failed at index {i}" + + +def test_select_symbol() -> None: + root = Container( + "a", + children=[ + Container( + "b", + 
children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("symbol", "d"))) + assert selected == [ + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ] + + +def test_select_depth() -> None: + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("depth", 1))) + assert selected == root.children + + selected = list(select(root, how=("depth", range(1, 3)))) + expected = [ + # Depth 1 + *root.children, + # Depth 2 + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + ] + assert selected == expected + + +def test_select_climb() -> None: + # NOTE: The order is rather arbitrary and not much thought has been given to it. + # However the test still tests a particular order that was done by trial and + # error. Feel free to redo the order if this changes. + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("climb", 0))) + assert selected == [ + Leaf("l3", op=T("l3")), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + Leaf("l1", op=T("l1")), + ] + + selected = list(select(root, how=("climb", range(1, 3)))) + expected = [ + root, + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + ] + for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): + assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" From 91c6910321f37f3b654dd35220b6f8789cd1d0fa Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 19:42:36 +0100 Subject: [PATCH 13/50] Rework mutations --- graph.py | 130 ++++++++++++++++++++------------------------ graph_playground.py | 47 ++++++++++++++++ 2 files changed, 107 insertions(+), 70 deletions(-) create mode 100644 graph_playground.py diff --git a/graph.py b/graph.py index b552f4bd4..be1346905 100644 --- a/graph.py +++ b/graph.py @@ -2,7 +2,7 @@ import itertools from collections import defaultdict -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field from functools import partial from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias @@ -412,87 +412,77 @@ def select( assert_never(how) -def mutate_leaf_parents( +def mutations( root: Node, grammar: Grammar, *, - rng: np.random.Generator, + which: 
Iterable[Node], + max_mutation_depth: int, + rng_shuffle: np.random.Generator | None = None, variables: dict[str, Node] | None = None, -) -> Node: - """Mutate a node, returning a different possibility for it.""" +) -> Iterator[Node]: + """Mutate nodes, returning all the different possibilities for them. + + Args: + root: The root from which to operate. + grammar: The grammar which holds the rules used for mutation. + which: What nodes to mutate, look at `select()`. + max_mutation_depth: The maximum depth allowed for bfs iteration + on the mutant nodes. + rng_shuffle: Whether to shuffle the return order. This takes place at the place + when considering the possibilities for a given node, and does not follow + the order of `NonTerminal.choices`. + variables: Any predefined values you'd like for different symbols. + + Returns: + A new tree per possible mutation + """ if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") - variables = variables or {} - parents: dict[int, Node] = {} - leaf_parents: list[Node] = [] + variables = variables or {} + mutation_ids = {id(n) for n in which} - def _fill(n: Node, *, parent: Node) -> None: - node_id = id(n) - parents[node_id] = parent - match n: - case Leaf(): - leaf_parents.append(parent) - case Passthrough(_, children) | Container(_, children): - for child in children: - _fill(child, parent=parent) - case _: - assert_never(n) - - for child in root.children: - _fill(child, parent=root) - - # Note, we can have duplicates here, that's fine, we want to weight those - # parents with many leafs more heavily... TODO: Maybe? - chosen_node: Node = rng.choice(leaf_parents) # type: ignore - chosen_node_id = id(chosen_node) - - match chosen_node: - case Passthrough() | Container(): - new_subnode = sample_grammar( - chosen_node.symbol, - grammar, - rng=rng, - # NOTE: subfunction will update variables dict - # with any instantiated `variables` if it doesn't - # exist already in the passed in `variables` - variables=variables, - ) - case Leaf(): - raise ValueError("don't pass leafs") - case _: - assert_never(chosen_node) - - def _build(n: Node): - # If we find the node to replace, replace it. - if id(n) == chosen_node_id: - return new_subnode - - # It may be the case that `sample_grammar` above populated - # `variables`, replacing one of the shared nodes with something - # new. In that case, we want to use the new sampled value wherever - # we encounter that symbol. 
- shared_node = variables.get(n.symbol) - if shared_node is not None: - return shared_node - - # Otherwise, we just rebuild as needed - match n: + def _inner(node: Node) -> Iterator[Node]: + match node: case Leaf(): - return n - case Container(symbol, children, op): - return Container(symbol, children=[_build(c) for c in children], op=op) - case Passthrough(symbol, children): - return Passthrough(symbol, children=[_build(c) for c in children]) - case _: - assert_never(n) + # We can't mutate leafs as they don't have possible choices to choose from + # by definition so we ignore it even if it's in the set of `mutation_ids` + yield node + case Passthrough(children=children) | Container(children=children): + rule = grammar.rules.get(node.symbol) + if not isinstance(rule, Grammar.NonTerminal): + raise ValueError( + "Expected a `NonTerminal` for symbol '{node.symbol}' from the" + f" grammar but got rule {rule}" + ) - return _build(root) + # If we've already determined the value of this shared symbol + if (existing := variables.get(node.symbol)) is not None: + yield existing + return + # If mutate, we return all possible bfs values from that node. + if id(node) in mutation_ids: + yield from bfs_grammar( + grammar, + node.symbol, + rng_shuffle=rng_shuffle, + max_depth=max_mutation_depth, + variables=variables, + ) + else: + children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] + for new_children in itertools.product(*children_itrs): + node = node._replace(children=list) + new_node = node._replace(children=new_children) + if rule.shared: + variables[new_node.symbol] = node + yield new_node + case _: + assert_never(node) -def mutate_many( - node: Node, grammar: Grammar, *, rng: np.random.Generator -) -> Iterator[Node]: ... + yield from _inner(root) def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: diff --git a/graph_playground.py b/graph_playground.py new file mode 100644 index 000000000..5b995feef --- /dev/null +++ b/graph_playground.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from graph import Grammar, mutations, parse, select, to_string + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p a", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +root = parse(grammar_1, "s(p(s(a), a))") + +selections = list(select(root, how=("climb", range(1, 3)))) +mutants = mutations( + root=root, + grammar=grammar_1, + which=selections, + max_mutation_depth=3, +) +mutants = list(mutants) + +import rich + +rich.print("grammar", grammar_1) +rich.print("root", f"{to_string(root)}") +rich.print("selections", [to_string(s) for s in selections]) +rich.print("mutants", [to_string(m) for m in mutants]) From 149cc8ce01e7518610741f369065cb8edc754888 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:44:19 +0100 Subject: [PATCH 14/50] Fix parsing --- graph.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/graph.py b/graph.py index be1346905..5cf69b093 100644 --- a/graph.py +++ b/graph.py @@ -621,29 +621,31 @@ def parse(grammar: Grammar, string: str) -> Node: bracket_pairs: dict[int, int] = {} for i, tok in enumerate(tokens): match tok: - case ( - "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) - ): + case "," | (_, Leaf()): continue case 
")": start = bracket_stack.pop(-1) bracket_pairs[start] = i case "(": bracket_stack.append(i) - case (symbol, Grammar.NonTerminal(shared=True)): + case (symbol, Grammar.NonTerminal(shared=shared)): if i + 1 >= len(tokens): raise ParseError( - f"Symbol '{tok}' is 'shared', implying that it should" + f"Symbol '{tok}' is a `NonTerminal`, implying that it should" " contain some inner elements. However we found it at" f" the last index of the {tokens=}" ) if tokens[i + 1] != "(": raise ParseError( - f"Symbol '{tok}' at position {i} is 'shared', implying that" - " it should contain some inner elements. However it was not" - f" followed by a '(' at position {i + 1} in {tokens=}" + f"Symbol '{tok}' at position {i} is a `NonTerminal`," + " implying that it should contain some inner elements." + f" However it was not followed by a '(' at position {i + 1}" + f" in {tokens=}" ) - _shared_locs[symbol].append(i) + if shared is True: + _shared_locs[symbol].append(i) + case _: + assert_never(tok) # If we have more than one occurence of a shared symbol, # we validate their subtokens match From 17b11a3fda050c1adcb7b61714a3c9aff87b0181 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:46:04 +0100 Subject: [PATCH 15/50] Fix mutation --- graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graph.py b/graph.py index 5cf69b093..8053dfb79 100644 --- a/graph.py +++ b/graph.py @@ -474,10 +474,9 @@ def _inner(node: Node) -> Iterator[Node]: else: children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] for new_children in itertools.product(*children_itrs): - node = node._replace(children=list) new_node = node._replace(children=new_children) if rule.shared: - variables[new_node.symbol] = node + variables[new_node.symbol] = new_node yield new_node case _: assert_never(node) From 0226334bfc133fda97d378dc7c40992ab5e1f0ed Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 15:23:35 +0100 Subject: [PATCH 16/50] fix: stop unpacking of nn.Sequential --- graph.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/graph.py b/graph.py index 8053dfb79..aac96ea6c 100644 --- a/graph.py +++ b/graph.py @@ -8,7 +8,6 @@ from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never -import more_itertools import networkx as nx import numpy as np from torch import nn @@ -47,22 +46,21 @@ def next(self, n: int) -> int: class ReLUConvBN(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + def __init__(self, out_channels, kernel_size, stride, padding): super().__init__() self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv3d( - in_channels, - out_channels, - kernel_size, + nn.LazyConv2d( + out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, dilation=2, bias=False, ), - nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), + nn.LazyBatchNorm2d(affine=True, track_running_stats=True), ) def forward(self, x): @@ -73,8 +71,8 @@ class Identity(nn.Module): def __init__(self): super().__init__() - def forward(self): - return self + def forward(self, x): + return x def dfs_node(node: Node) -> Iterator[Node]: @@ -881,30 +879,33 @@ def bfs_grammar( # noqa: C901, D103 def to_model(node: Node) -> Any: """Convert a parse tree node and its children into some object it represents.""" - def _build(_n: Node) -> Iterator[Any]: + def _build(_n: Node) -> list[Any] 
| Any: match _n: case Leaf(_, op): - yield op() + return op() case Container(_, children, op): # The problem is that each child could be either: # * A single 'thing', in the case of Leaf or Container # * Multiple things, in case it's a passthrough # Hence we flatten them out into a single big children itr - flat_children = more_itertools.collapse( - _build(child) for child in children - ) - yield op(*flat_children) + _l = [] + for child in children: + _b = _build(child) + if isinstance(_b, list): + _l.extend(_b) + continue + _l.append(_b) + + return op(*_l) case Passthrough(_, children): - yield from (_build(child) for child in children) + return [_build(child) for child in children] case _: assert_never(node) match node: case Leaf() | Container(): - itr = _build(node) - obj = next(itr, None) - assert obj is not None, "Should have recieved at least one object" - assert next(itr, None) is None, "Should not have recieved two objects" + obj = _build(node) + assert not isinstance(obj, list) return obj case Passthrough(symbol): raise ValueError(f"Can not call build on a `Passthrough` {symbol}") From 241d9da3e639b9bfedaa473f6d9d76dcc55af9a7 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 16:03:39 +0100 Subject: [PATCH 17/50] tests: add mlp end-to-end test --- test_graph.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test_graph.py b/test_graph.py index 09759a51b..df10575d0 100644 --- a/test_graph.py +++ b/test_graph.py @@ -1,8 +1,12 @@ from __future__ import annotations +import time from dataclasses import dataclass +from functools import partial +import numpy as np import pytest +import torch from graph import ( Container, Grammar, @@ -13,12 +17,14 @@ bfs_node, dfs_node, parse, + sample_grammar, select, to_model, to_node_from_graph, to_nxgraph, to_string, ) +from torch import nn # Leafs @@ -54,6 +60,22 @@ def join(*s: str) -> str: } ) +grammar_3 = Grammar.from_dict( + { + "S": (["mlp", "O"], nn.Sequential), + "mlp": (["L", "O", "S O"], nn.Sequential), + "L": ( + ["linear64 linear128 relu O linear64 relu O", "linear64 elu linear64"], + nn.Sequential, + ), + "O": (["linear64", "linear64 relu", "linear128 elu"], nn.Sequential), + "linear64": partial(nn.LazyLinear, out_features=64), + "linear128": partial(nn.LazyLinear, out_features=64), + "relu": nn.ReLU, + "elu": nn.ELU, + } +) + @pytest.mark.parametrize( ("grammar", "string", "built", "node"), @@ -432,3 +454,20 @@ def test_select_climb() -> None: ] for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" + + +@pytest.mark.parametrize("grammar", [grammar_3]) +def test_sample_grammar_and_build_model(grammar: Grammar): + rng = np.random.default_rng(seed=42) + + x = torch.randn(32, 100) + + t0 = time.perf_counter() + samples = 1_000 + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + model: nn.Module = to_model(sample) + model(x) + assert sum(p.numel() for p in model.parameters()) > 0 + + assert time.perf_counter() - t0 < 1 From 8f186fb648d91ac928eec364e4174aecb347c64e Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 7 Feb 2025 18:19:28 +0100 Subject: [PATCH 18/50] yo --- graph.py | 931 ++++++++++++++++++++++++++++++++++++++++++++++++++ perf.py | 32 ++ test_graph.py | 160 +++++++++ 3 files changed, 1123 insertions(+) create mode 100644 graph.py create mode 100644 perf.py create mode 100644 test_graph.py diff --git a/graph.py b/graph.py new file mode 
100644 index 000000000..e8bc8d2a2 --- /dev/null +++ b/graph.py @@ -0,0 +1,931 @@ +from __future__ import annotations + +import itertools +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing_extensions import assert_never + +import more_itertools +import networkx as nx +from torch import nn + +from neps.exceptions import NePSError + +if TYPE_CHECKING: + import numpy as np + + +class ParseError(NePSError): + pass + + +class ReLUConvBN(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + super().__init__() + + self.kernel_size = kernel_size + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=1, + bias=False, + ), + nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + ) + + def forward(self, x): + return self.op(x) + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self): + return self + + +class Leaf(NamedTuple): + symbol: str + op: Callable + + +class Container(NamedTuple): + symbol: str + children: list[Node] + op: Callable + + +class Passthrough(NamedTuple): + symbol: str + children: list[Node] + + +Node: TypeAlias = Container | Passthrough | Leaf + + +@dataclass +class Tree: + root: Container | Leaf + + nodes: dict[int, Node] + + children_ids_of: dict[int, list[int]] + parent_id_of: dict[int, int] + leafs: list[int] + + @classmethod + def from_node(cls, node: Node) -> Tree: + """Create a `Tree` from a node, where node is considered the root.""" + nodes: dict[int, Node] = {} + children_ids_of: dict[int, list[int]] = {} + parent_id_of: dict[int, int] = {} + + def _traverse(n: Node, parent_id: int | None = None) -> None: + node_id = id(n) + nodes[node_id] = n + + if parent_id is not None: + parent_id_of[node_id] = parent_id + children_ids_of[parent_id].append(node_id) + + match n: + case Leaf(): + pass + case Container(_, children, _) | Passthrough(_, children): + children_ids_of[node_id] = [] + for child in children: + _traverse(child, node_id) + case _: + assert_never(n) + + _traverse(node) + + # Validate node is a Container or Leaf + if not isinstance(node, Container | Leaf): + raise ValueError("Root node must be a Container or Leaf") + + return cls( + root=node, + nodes=nodes, + children_ids_of=children_ids_of, + parent_id_of=parent_id_of, + leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], + ) + + +@dataclass +class Grammar: + rules: dict[str, Terminal | NonTerminal] + + class Terminal(NamedTuple): + op: Callable + shared: bool = False + + class NonTerminal(NamedTuple): + choices: list[str] + op: Callable | None = None + shared: bool = False + + @classmethod + def from_dict( + cls, + grammar: dict[ + str, + Callable + | list[str] + | tuple[list[str], Callable] + | Grammar.Terminal + | Grammar.NonTerminal, + ], + ) -> Grammar: + rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} + for symbol, rule in grammar.items(): + match rule: + case Grammar.Terminal() | Grammar.NonTerminal(): + rules[symbol] = rule + case (choices, op) if isinstance(choices, list) and callable(op): + # > e.g. 
"S": (["A", "A B", "C"], op) + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, op, shared=False) + + case choices if isinstance(choices, list): + # > e.g. "S": ["A", "A B", "C"] + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + + case op if callable(op): + # > e.g. "S": op + rules[symbol] = Grammar.Terminal(op, shared=False) + case _: + raise ValueError( + f"The rule for symbol {symbol} is not recognized. Should be" + " a list of of symbols, a callable or a tuple with both." + f"\n Got {rule}" + ) + + return Grammar(rules) + + +def sample_grammar( + symbol: str, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + variables = variables or {} + rule = grammar.rules.get(symbol) + if rule is None: + raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices, op): + chosen_children = rng.choice(choices).split(" ") + children = [ + sample_grammar(child_symbol, grammar, rng=rng, variables=variables) + for child_symbol in chosen_children + ] + if op is None: + node = Passthrough(symbol, children=children) + else: + node = Container(symbol, op=op, children=children) + case _: + assert_never(rule) + + if rule.shared: + variables[symbol] = node + + return node + + +def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: + # Find the unique root (a node with no incoming edges) + _root = next((n for n, d in graph.in_degree if d == 0), None) + if _root is None: + raise ValueError( + "Could not find a root in the given graph (a node with indegree 0)." 
+ ) + + variables: dict[str, Node] = {} + + def _recurse(node_id: int) -> Node: + symbol = graph.nodes[node_id].get("label") + if symbol is None: + raise ValueError(f"Node {node_id} does not have a 'label' property.") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + rule = grammar.rules.get(symbol) + if rule is None: + raise ValueError( + f"Symbol '{symbol}' not found in grammar rules: {grammar.rules.keys()}" + ) + + # Based on the type of rule, construct the proper node + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices=_, op=op): + children = [_recurse(child_id) for child_id in graph.successors(node_id)] + if op is None: + node = Passthrough(symbol, children) + else: + node = Container(symbol, children, op) + case _: + raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") + + if rule.shared: + variables[symbol] = node + + return node + + # Start with the root node + return _recurse(_root) + + +def mutate_leaf_parents( + root: Node, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + """Mutate a node, returning a different possibility for it.""" + if isinstance(root, Leaf): + raise ValueError(f"Can't mutate `Leaf`: {root}") + variables = variables or {} + tree: Tree = Tree.from_node(node=root) + + # Note, we can have duplicates here, that's fine, we want to weight those + # parents with many leafs more heavily... TODO: Maybe? + parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] + + chosen_node_id: int = rng.choice(parents) + chosen_node: Node = tree.nodes[chosen_node_id] + + match chosen_node: + case Passthrough() | Container(): + new_subnode = sample_grammar( + chosen_node.symbol, + grammar, + rng=rng, + # NOTE: subfunction will update variables dict + # with any instantiated `variables` if it doesn't + # exist already in the passed in `variables` + variables=variables, + ) + case Leaf(): + raise ValueError("don't pass leafs") + case _: + assert_never(chosen_node) + + def _build(n: Node): + # If we find the node to replace, replace it. + if id(n) == chosen_node_id: + return new_subnode + + # It may be the case that `sample_grammar` above populated + # `variables`, replacing one of the shared nodes with something + # new. In that case, we want to use the new sampled value wherever + # we encounter that symbol. + shared_node = variables.get(n.symbol) + if shared_node is not None: + return shared_node + + # Otherwise, we just rebuild as needed + match n: + case Leaf(): + return n + case Container(symbol, children, op): + return Container(symbol, children=[_build(c) for c in children], op=op) + case Passthrough(symbol, children): + return Passthrough(symbol, children=[_build(c) for c in children]) + case _: + assert_never(n) + + return _build(root) + + +def mutate_many( + node: Node, grammar: Grammar, *, rng: np.random.Generator +) -> Iterator[Node]: ... + + +# TODO: This has issues as we are using id's, while we may have heirarchical components +# which share the same id. 
+def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: + nodes: list[tuple[int, dict]] = [] + edges: list[tuple[int, int]] = [] + id_generator: Iterator[int] = itertools.count() + + def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: + node_id = next(id_generator) + match node: + # Atoms are just a node with an edge to its parent + case Leaf(symbol): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + # If we have a passthrough and shouldn't include them, we simply + # forward on the `parent_id` we received to the children + case Passthrough(_, children) if include_passthroughs is False: + for child in children: + _recurse_fill_lists(child, parent_id=parent_id) + + # Containers are a node in the graph, with edges to its + # children (direct, or through passthrough) + case Container(symbol, children, _) | Passthrough(symbol, children): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + for child in children: + _recurse_fill_lists(child, parent_id=node_id) + + case _: + assert_never(root.kind) + + graph = nx.DiGraph() + root_id = next(id_generator) + match root: + case Leaf(): + nodes.append((root_id, {"label": root.symbol})) + case Passthrough(_, children) if include_passthroughs is False: + raise ValueError( + f"Can't create a graph starting from a `Passthrough` {root.symbol}, " + " unless `include_passthroughs=True`" + ) + case Container(_, children, _) | Passthrough(_, children): + for child in children: + _recurse_fill_lists(child, parent_id=root_id) + case _: + assert_never(root) + + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, c in enumerate(string): + match c: + case "(": + bracket_stack.append(i) + case ")": + if len(bracket_stack) == 0: + raise ParseError( + f"Encountered mismatched brackets at position {i}" + f" in string '{string}'" + ) + bracket_start = bracket_stack.pop(-1) + bracket_pairs[bracket_start] = i + case _: + continue + + if len(bracket_stack) > 0: + raise ParseError( + "Encountered a mismatch in the number of brackets." + f" The bracket(s) at position {bracket_stack} were never closed" + f" in the string '{string}'" + ) + + variables: dict[str, Node] = {} + + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + symbol = "" + i = frm + while i <= to: # Use a while loop as we may jump ahead in the loop + c = string[i] + match c: + # Ignore whitespace + case " " | "\n" | "\t": + i += 1 + # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed + # out a symbol from the s(a). Should only occur after a ")" + case "," if symbol == "": + assert string[i - 1] == ")" + i += 1 + # If the last character of a substring is a comma, this + # is not a valid string. + case "," if i == to: + raise ParseError( + "Got a (sub)string terminating in a ','." + " The ',' indicates something should come after it." + f" {string[frm : to + 1]}" + ) + # Otherwise, it's a valid ',' with a symbol before it + case ",": + i += 1 + node_symbol = symbol + symbol = "" + + rule = grammar.rules.get(node_symbol) + if rule is None: + raise ParseError( + f"Symbol '{node_symbol}' not in grammar" + f" {grammar.rules.keys()}" + ) + + # We parse out the node, even if it's shared, as we need to ensure + # what we parse out would match whatever is in the shared variables.
+ match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(): + raise ParseError( + f"`NonTerminal` '{node_symbol}' can not be followed" + " by a comma ',' as it contains children inside brackets" + " '()'" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure to return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + # If we encounter an open bracket with no preceding token, + # then this is invalid + case "(" if symbol == "": + raise ParseError( + "Encountered an open bracket '(' without any" + f" symbol parsed before it in string {string[frm : to + 1]} " + ) + # Open a new subtree + case "(": + assert i in bracket_pairs + + # Find out where we need to parse to get the children + bracket_start = i + bracket_end = bracket_pairs[bracket_start] + assert bracket_end <= to, f"{bracket_end=} > {to=}" + children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + + # Advance the tokenizer past the end of that bracket + i = bracket_end + 1 + + # Reset the symbol + node_symbol = symbol + symbol = "" + + # Build the node with its children + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.NonTerminal(_, op): + if strict: + child_substring = " ".join( + child.symbol for child in children + ) + if child_substring not in rule.choices: + substring = string[bracket_start : bracket_end + 1] + raise ParseError( + f"While {substring=} is parsable, the children" + f" '{child_substring}' is not one of the valid" + f" choices for '{node_symbol} : {rule.choices}." + " To allow this anyway, pass `strict=False` to" + " this call." + ) + + if op is None: + node = Passthrough(node_symbol, children) + else: + node = Container(node_symbol, children, op) + case Grammar.Terminal(op): + raise ParseError("Encountered a '(' after a Terminal.") + case None: + raise ParseError( + f"No associated rule with {node_symbol=}. Available" + f" tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure to return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case ")" if symbol == "": + # This occurs in repeated brackets and is fine + # > 's(s(a))' + i += 1 + continue + case ")": + # If we reached this bracket, just make sure the parsing algorithm + # is working correctly by checking we are indeed where we think + # we should be which is at `to` + assert i == to + i += 1 + + node_symbol = symbol + symbol = "" # This should be the end of the recursed call anyhow + + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError("A ')' should never follow a `NonTerminal`") + case None: + raise ParseError( + f"No associated rule with {symbol=}. 
Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case _: + i += 1 + symbol += c # Append to current token + + # This occurs when we did not encounter any special characters + # like `,`, `(` or `)`. + # I'm pretty sure the only case this can happen is if we have something + # like the string `"b"`, i.e. just a `Leaf` + if symbol != "": + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError( + "Did not expected to have `NonTerminal` without" + " special characters '(', ')' or ','" + ) + case None: + raise ParseError( + f"No associated rule with {symbol=}. Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + yield node + + itr = _parse(frm=0, to=len(string) - 1) + root_token = next(itr, None) + second_token = next(itr, None) + if second_token is not None: + raise ParseError( + "If getting the root as a `Leaf`, then we should have no proceeding tokens." + ) + + match root_token: + case Leaf() | Container(): + return root_token + case Passthrough(): + raise ParseError("Should not have recieved a `Passthrough` as the root token") + case None: + raise ParseError(f"No token was parsed, was the string empty? {string=}") + case _: + assert_never(root_token) + + +# NOTE: Not sure we want this as a standalone function, but it serves to show some logic +def is_valid( + grammar: Grammar, + node: Node, + *, + already_shared: set[str] | None = None, +) -> bool: + rule = grammar.rules.get(node.symbol) + if rule is None: + raise ValueError( + f"Node has unknown symbol {node.symbol}, valid symbols are" + f" {grammar.rules.keys()}" + ) + + # We should never encounter a situtation where we have some nesting of shared nodes, + # for example, consider the following, where L1 is shared. + # L1 -> x -> ... -> L1 -> x -> ... + already_shared = already_shared or set() + if rule.shared and node.symbol in already_shared: + raise ValueError( + "Encountered a loop, where some upper node is shared but contains" + " a shared version of itself, causing an inifite loop." + ) + + match node: + case Leaf(symbol): + return symbol in grammar.rules + case Container(symbol, children, _) | Passthrough(symbol, children): + s = " ".join(child.symbol for child in children) + + match rule: + case Grammar.Terminal(_): + return s in grammar.rules and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case Grammar.NonTerminal(choices, _): + return s in choices and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case _: + assert_never(rule) + case _: + assert_never(node) + + +# TODO: Optimization, we don't need to recompute shared substrings. 
+# This is likely not worth it unless we have really deep trees +def to_string(node: Node) -> str: + match node: + case Leaf(symbol): + return symbol + case Passthrough(symbol, children) | Container(symbol, children): + return f"{symbol}({', '.join(to_string(c) for c in children)})" + case _: + assert_never(node) + + +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + stack.extend(reversed(children)) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + queue.extend(children) + + +# TODO: The variables thing can mess up the max depth +def bfs_grammar( + grammar: Grammar, + symbol: str, + *, + max_depth: int, + current_depth: int = 0, + variables: dict[str, Node] | None = None, + rng_shuffle: np.random.Generator | None = None, +) -> Iterator[Node]: + if current_depth > max_depth: + return + + variables = variables or {} + shared_node = variables.get(symbol) + if shared_node is not None: + yield shared_node + return # TODO: check + + nxt_depth = current_depth + 1 + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + if rule.shared: + variables[symbol] = node + yield node + case Grammar.NonTerminal(choices, op): + for choice in choices: + children = choice.split(" ") + child_expansions: list[Iterator] = [ + bfs_grammar( + grammar, + child_symbol, + max_depth=max_depth, + current_depth=nxt_depth, + rng_shuffle=rng_shuffle, + variables=variables, + ) + for child_symbol in children + ] + + if rng_shuffle: + # This works correctly with python lists, but typing for numpy is off + rng_shuffle.shuffle(child_expansions) # type: ignore + + for possible in itertools.product(*child_expansions): + if op is None: + node = Passthrough(symbol, children=list(possible)) + else: + node = Container(symbol, op=op, children=list(possible)) + + if rule.shared: + variables[symbol] = node + + yield node + case None: + raise ValueError( + f"Could not find symbol {symbol} in table with keys {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + +def to_model(node: Node) -> Any: + def _build(_n: Node) -> Iterator[Any]: + match _n: + case Leaf(_, op): + yield op() + case Container(_, children, op): + # The problem is that each child could be either: + # * A single 'thing', in the case of Leaf or Container + # * Multiple things, in case it's a passthrough + # Hence we flatten them out into a single big children itr + flat_children = more_itertools.collapse( + _build(child) for child in children + ) + import rich + + rich.print(flat_children) + yield op(*flat_children) + case Passthrough(_, children): + yield from (_build(child) for child in children) + case _: + assert_never(node) + + match node: + case Leaf() | Container(): + itr = _build(node) + obj = next(itr, None) + assert obj is not None, "Should have received at least one object" + assert next(itr, None) is None, "Should not have received two objects" + return obj + case Passthrough(symbol): + raise ValueError(f"Can not call build on a `Passthrough` {symbol}") + case _: + assert_never(node) + + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"],
nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +# https://stackoverflow.com/a/29597209 +def hierarchy_pos( + G: nx.DiGraph, + root: int, + width: float = 1.0, + vert_gap: float = 0.2, + vert_loc: float = 0, + xcenter: float = 0.5, +) -> dict[int, tuple[float, float]]: + """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + Licensed under Creative Commons Attribution-Share Alike. + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + """ + if not nx.is_tree(G): + raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") + + def _hierarchy_pos( + G, + root, + width=1.0, + vert_gap=0.2, + vert_loc: float = 0, + xcenter=0.5, + pos: dict[int, tuple[float, float]] | None = None, + parent=None, + ) -> dict[int, tuple[float, float]]: + """See hierarchy_pos docstring for most arguments. + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. 
- only affects it if non-directed + + """ + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children) != 0: + dx = width / len(children) + nextx = xcenter - width / 2 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos( + G, + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=nextx, + pos=pos, + parent=root, + ) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) diff --git a/perf.py b/perf.py new file mode 100644 index 000000000..63e37062e --- /dev/null +++ b/perf.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from functools import partial + +import numpy as np +from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from torch import nn + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +if __name__ == "__main__": + sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) diff --git a/test_graph.py b/test_graph.py new file mode 100644 index 000000000..625383f23 --- /dev/null +++ b/test_graph.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +from graph import ( + Container, + Grammar, + Leaf, + Node, + ParseError, + Passthrough, + parse, + to_model, + to_string, +) + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +grammar_2 = Grammar.from_dict( + { + "L1": (["L2 L2 L3"], join), + "L2": Grammar.NonTerminal(["L3"], join, shared=True), + "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), + "a": T("a"), + "b": T("a"), + } +) + + +@pytest.mark.parametrize( + ("grammar", "string", "built", "node"), + [ + (grammar_1, "a", "a", Leaf("a", T("a"))), + (grammar_1, "b", "b", Leaf("b", T("b"))), + ( + grammar_1, + "s(a)", + "[a]", + Container("s", op=join, children=[Leaf("a", T("a"))]), + ), + ( + grammar_1, + "s(p(a, b))", + "[ab]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(a, b), p(s(a)))", + "[ab[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + Passthrough( + "p", + children=[Container("s", children=[Leaf("a", T("a"))], op=join)], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(s(a)))", + "[[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[ + Container( + "s", + children=[Leaf("a", T("a"))], + op=join, + ) + ], + ), + ], + op=join, + ), + ), + ], +) +def test_string_serialization_and_deserialization_correct( + grammar: Grammar, + string: str, + built: str, + node: Node, +) -> 
None: + # Test parsing + parsed = parse(grammar, string) + assert parsed == node + + # Test serialization + serialized_again = to_string(parsed) + assert serialized_again == string + + # Test building + assert to_model(parsed) == built + + +@pytest.mark.parametrize( + ("grammar", "string"), + [ + (grammar_1, "c"), + (grammar_1, ""), + (grammar_1, "s(a"), + (grammar_1, "p(a, b)"), + (grammar_1, "("), + (grammar_1, "s(a))"), + (grammar_1, "s((a)"), + (grammar_1, "s("), + (grammar_1, "s)"), + (grammar_1, "a, a"), + (grammar_1, "a,"), + (grammar_1, "s, s"), + # Invalid due to shared rule but not sharing values + (grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))"), + ], +) +def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: + with pytest.raises(ParseError): + parse(grammar, string) From 670183a81ecee3c8c6a207a8745fca47fe20acd2 Mon Sep 17 00:00:00 2001 From: Timur Carstensen Date: Fri, 7 Feb 2025 18:56:20 +0100 Subject: [PATCH 19/50] chore: perf testing --- graph.py | 4 ++-- perf.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/graph.py b/graph.py index e8bc8d2a2..b366a9ae4 100644 --- a/graph.py +++ b/graph.py @@ -808,9 +808,9 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - import rich + # import rich - rich.print(flat_children) + # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) diff --git a/perf.py b/perf.py index 63e37062e..901b22f9c 100644 --- a/perf.py +++ b/perf.py @@ -3,13 +3,22 @@ from functools import partial import numpy as np -from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from graph import ( + Grammar, + Identity, + Node, + ReLUConvBN, + parse, + sample_grammar, + to_nxgraph, + to_string, +) from torch import nn structure = { "S": ( Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], + ["C", "reluconvbn", "S", "S C", "O O O", "S S O O O O O O"], nn.Sequential, ) ), @@ -29,4 +38,27 @@ if __name__ == "__main__": - sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) + import time + + import rich + + grammar = Grammar.from_dict(structure) + rng = np.random.default_rng() + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + graph = to_nxgraph(sample) + # model = to_model(sample) + + t0 = time.perf_counter() + samples = 10000 + + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + string = to_string(sample) + parse(string=string, grammar=grammar) + # graph = to_nxgraph(sample) + # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) + # model = to_model(sample) + + t1 = time.perf_counter() + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") + rich.print(f"duration for {samples} samples: {t1 - t0}s ") From 6a1ac28edfd01e48bd4913d261e71a18736c260c Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Sun, 9 Feb 2025 23:43:09 +0100 Subject: [PATCH 20/50] optimizations on parsing and test --- graph.py | 596 ++++++++++++++++++++++++++++++++++---------------- test_graph.py | 121 ++++++++++ 2 files changed, 533 insertions(+), 184 deletions(-) diff --git a/graph.py b/graph.py index b366a9ae4..e440a2896 100644 --- a/graph.py +++ b/graph.py @@ -2,25 +2,49 @@ import itertools from collections.abc import Callable, Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import 
partial -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never import more_itertools import networkx as nx +import numpy as np from torch import nn from neps.exceptions import NePSError -if TYPE_CHECKING: - import numpy as np - class ParseError(NePSError): pass +# OPTIM: Calling `np.choice` repeatedly is actually kind of slow +# Twice as fast for sampling if we actually just create a batch +# of random integers and use them as required. +@dataclass +class BufferedRandIntStream: + rng: np.random.Generator + buffer_size: int = 51 + _cur_ix: int = 2 + + MAX_INT: ClassVar[int] = np.iinfo(np.int64).max + _nums: list[int] = field(default_factory=list) + + def next(self, n: int) -> int: + if self._cur_ix >= len(self._nums): + self._nums = self.rng.integers( + self.MAX_INT, size=self.buffer_size, dtype=np.int64 + ).tolist() + + self._cur_ix = 1 + + i = self._nums[self._cur_ix] % n + + self._cur_ix += 2 + return i + + class ReLUConvBN(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding): super().__init__() @@ -28,16 +52,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding): self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv2d( + nn.Conv3d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=1, + dilation=2, bias=False, ), - nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), ) def forward(self, x): @@ -52,88 +76,91 @@ def forward(self): return self +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + stack.extend(reversed(children)) + case _: + assert_never(nxt) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + queue.extend(children) + case _: + assert_never(nxt) + + class Leaf(NamedTuple): symbol: str op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Container(NamedTuple): symbol: str children: list[Node] op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Passthrough(NamedTuple): symbol: str children: list[Node] + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node -Node: TypeAlias = Container | Passthrough | Leaf - - -@dataclass -class Tree: - root: Container | Leaf - - nodes: dict[int, Node] - - children_ids_of: dict[int, list[int]] - parent_id_of: dict[int, int] - leafs: list[int] - @classmethod - def from_node(cls, node: Node) -> Tree: - """Create a `Tree` from a node, where node is considered the root.""" - nodes: dict[int, Node] = {} - children_ids_of: dict[int, list[int]] = {} - parent_id_of: dict[int, int] = {} - - def _traverse(n: Node, parent_id: int | None = None) -> None: - node_id = id(n) - nodes[node_id] = n - - if parent_id is not None: - parent_id_of[node_id] = parent_id - children_ids_of[parent_id].append(node_id) - - match n: - case Leaf(): - pass - case Container(_, children, _) | Passthrough(_, children): - children_ids_of[node_id] = [] - for child in children: - _traverse(child, node_id) - case _: - assert_never(n) - - _traverse(node) - - # 
Validate node is a Container or Leaf - if not isinstance(node, Container | Leaf): - raise ValueError("Root node must be a Container or Leaf") - - return cls( - root=node, - nodes=nodes, - children_ids_of=children_ids_of, - parent_id_of=parent_id_of, - leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], - ) +Node: TypeAlias = Container | Passthrough | Leaf @dataclass class Grammar: rules: dict[str, Terminal | NonTerminal] + _shared: dict[str, NonTerminal] = field(init=False) + _leafs: dict[str, Leaf] = field(init=False) class Terminal(NamedTuple): op: Callable - shared: bool = False class NonTerminal(NamedTuple): choices: list[str] op: Callable | None = None shared: bool = False + def __post_init__(self) -> None: + self._shared = { + s: r + for s, r in self.rules.items() + if isinstance(r, Grammar.NonTerminal) and r.shared + } + self._leafs = { + s: Leaf(s, r.op) + for s, r in self.rules.items() + if isinstance(r, Grammar.Terminal) + } + @classmethod def from_dict( cls, @@ -167,11 +194,11 @@ def from_dict( if any(missing): raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") - rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False) case op if callable(op): # > e.g. "S": op - rules[symbol] = Grammar.Terminal(op, shared=False) + rules[symbol] = Grammar.Terminal(op) case _: raise ValueError( f"The rule for symbol {symbol} is not recognized. Should be" @@ -186,23 +213,28 @@ def sample_grammar( symbol: str, grammar: Grammar, *, - rng: np.random.Generator, + rng: np.random.Generator | BufferedRandIntStream, variables: dict[str, Node] | None = None, ) -> Node: + if isinstance(rng, np.random.Generator): + rng = BufferedRandIntStream(rng=rng) + variables = variables or {} rule = grammar.rules.get(symbol) if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - match rule: - case Grammar.Terminal(op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(choices, op): - chosen_children = rng.choice(choices).split(" ") + case Grammar.Terminal(): + return grammar._leafs[symbol] + case Grammar.NonTerminal(choices=choices, op=op): + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + i = rng.next(len(choices)) + choice = choices[i] + chosen_children = choice.split(" ") children = [ sample_grammar(child_symbol, grammar, rng=rng, variables=variables) for child_symbol in chosen_children @@ -211,13 +243,13 @@ def sample_grammar( node = Passthrough(symbol, children=children) else: node = Container(symbol, op=op, children=children) - case _: - assert_never(rule) - if rule.shared: - variables[symbol] = node + if rule.shared: + variables[symbol] = node - return node + return node + case _: + assert_never(rule) def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: @@ -225,7 +257,7 @@ def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: _root = next((n for n, d in graph.in_degree if d == 0), None) if _root is None: raise ValueError( - "Could not find a root in the given graph (a node with indegree 0)." + "Could not find a root in the given graph (a node with indegree 1)." 
) variables: dict[str, Node] = {} @@ -235,10 +267,6 @@ def _recurse(node_id: int) -> Node: if symbol is None: raise ValueError(f"Node {node_id} does not have a 'label' property.") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - rule = grammar.rules.get(symbol) if rule is None: raise ValueError( @@ -249,18 +277,21 @@ def _recurse(node_id: int) -> Node: match rule: case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(choices=_, op=op): + case Grammar.NonTerminal(op=op): + if (shared_node := variables.get(symbol)) is not None: + return shared_node + children = [_recurse(child_id) for child_id in graph.successors(node_id)] - if op is None: - node = Passthrough(symbol, children) - else: - node = Container(symbol, children, op) + node = ( + Passthrough(symbol, children) + if op is None + else Container(symbol, children, op) + ) + if rule.shared: + variables[symbol] = node case _: raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") - if rule.shared: - variables[symbol] = node - return node # Start with the root node @@ -278,14 +309,29 @@ def mutate_leaf_parents( if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") variables = variables or {} - tree: Tree = Tree.from_node(node=root) + + parents: dict[int, Node] = {} + leaf_parents: list[Node] = [] + + def _fill(n: Node, *, parent: Node) -> None: + node_id = id(n) + parents[node_id] = parent + match n: + case Leaf(): + leaf_parents.append(parent) + case Passthrough(_, children) | Container(_, children): + for child in children: + _fill(child, parent=parent) + case _: + assert_never(n) + + for child in root.children: + _fill(child, parent=root) # Note, we can have duplicates here, that's fine, we want to weight those # parents with many leafs more heavily... TODO: Maybe? - parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] - - chosen_node_id: int = rng.choice(parents) - chosen_node: Node = tree.nodes[chosen_node_id] + chosen_node: Node = rng.choice(leaf_parents) # type: ignore + chosen_node_id = id(chosen_node) match chosen_node: case Passthrough() | Container(): @@ -335,8 +381,6 @@ def mutate_many( ) -> Iterator[Node]: ... -# TODO: This has issues as we are using id's, while we may have heirarchical components -# which share the same id. 
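# Converts a `Node` tree into a `networkx.DiGraph`: each tree node gets a unique
# integer id and its grammar symbol is stored under the "label" attribute, which
# is what `to_node_from_graph` reads when converting back. A minimal round-trip
# sketch (assuming the lowercase test grammar used elsewhere in this series):
#   node = sample_grammar("s", grammar=grammar, rng=np.random.default_rng())
#   graph = to_nxgraph(node, include_passthroughs=True)
#   assert to_node_from_graph(graph, grammar) == node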
def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: nodes: list[tuple[int, dict]] = [] edges: list[tuple[int, int]] = [] @@ -370,9 +414,10 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: graph = nx.DiGraph() root_id = next(id_generator) + nodes.append((root_id, {"label": root.symbol})) match root: case Leaf(): - nodes.append((root_id, {"label": root.symbol})) + pass case Passthrough(_, children) if include_passthroughs is False: raise ValueError( f"Can't create a graph starting from a `Passthrough` {root.symbol}, " @@ -389,7 +434,7 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: +def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: bracket_stack: list[int] = [] bracket_pairs: dict[int, int] = {} for i, c in enumerate(string): @@ -397,17 +442,17 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: case "(": bracket_stack.append(i) case ")": - if len(bracket_stack) == 0: + if len(bracket_stack) == 1: raise ParseError( f"Encountered mismatched brackets at position {i}" f" in string '{string}'" ) - bracket_start = bracket_stack.pop(-1) + bracket_start = bracket_stack.pop(0) bracket_pairs[bracket_start] = i case _: continue - if len(bracket_stack) > 0: + if len(bracket_stack) > 1: raise ParseError( "Encountered a mismatch in the number of brackets." f"The bracket(s) at position {bracket_stack} were never closed" @@ -416,31 +461,30 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: variables: dict[str, Node] = {} - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 symbol = "" i = frm while i <= to: # Use a while loop as we may jump ahead in the loop c = string[i] match c: # Ignore whiespace - case " " | "\n" | "\t": - i += 1 + case _ if c in (" \n\t"): + i += 2 # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed # out a symbol from the s(a). Should only occur after a ")" case "," if symbol == "": - assert string[i - 1] == ")" - i += 1 + i += 2 # If the last character of a substring ends in a comma, this # is not a valid string. case "," if i == to: raise ParseError( "Got a (sub)string terminating in a ','." " The ',' indicates something should come after it." - f" {string[frm : to + 1]}" + f" {string[frm : to + 2]}" ) # Otherwise, it's a valid ',' with a symbol before it case ",": - i += 1 + i += 2 node_symbol = symbol symbol = "" @@ -454,7 +498,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # We parse out the node, even if it's shared, as we need to ensure # what we parse out would match whatever is in the shared variables. 
match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) case Grammar.NonTerminal(): raise ParseError( @@ -486,20 +530,17 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case "(" if symbol == "": raise ParseError( "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 1]} " + f" symbol parsed before it in string {string[frm : to + 2]} " ) # Open a new subtree case "(": - assert i in bracket_pairs - # Find out where we need to parse to get the children bracket_start = i bracket_end = bracket_pairs[bracket_start] - assert bracket_end <= to, f"{bracket_end=} > {to=}" - children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + children = list(_parse(frm=bracket_start + 2, to=bracket_end)) # Advance the tokenizer past the end of that bracket - i = bracket_end + 1 + i = bracket_end + 2 # Reset the symbol node_symbol = symbol @@ -508,13 +549,13 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # Build the node with it's children rule = grammar.rules.get(node_symbol) match rule: - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): if strict: child_substring = " ".join( - child.symbol for child in children + [child.symbol for child in children] ) if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 1] + substring = string[bracket_start : bracket_end + 2] raise ParseError( f"While {substring=} is parsable, the children" f" '{child_substring}' is not one of the valid" @@ -527,7 +568,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 node = Passthrough(node_symbol, children) else: node = Container(node_symbol, children, op) - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): raise ParseError("Encountered a '(' after a Terminal.") case None: raise ParseError( @@ -556,23 +597,22 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case ")" if symbol == "": # This occurs in repeated brackets and is fine # > 's(s(a))' - i += 1 + i += 2 continue case ")": # If we reached this bracket, just make sure the parsing algorithm # is working correctly by checking we are indeed where we think # we should be which is at `to` - assert i == to - i += 1 + i += 2 node_symbol = symbol symbol = "" # This should be the end of the recursed call anywho rule = grammar.rules.get(node_symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError("A ')' should never follow a `NonTerminal`") case None: raise ParseError( @@ -599,7 +639,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node case _: - i += 1 + i += 2 symbol += c # Append to current token # This occurs when we did not encounter any special characters @@ -609,9 +649,9 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 if symbol != "": rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError( "Did not expected to have `NonTerminal` without" " special characters '(', ')' or ','" @@ -626,7 +666,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node - itr = _parse(frm=0, 
to=len(string) - 1) + itr = _parse(frm=1, to=len(string) - 1) root_token = next(itr, None) second_token = next(itr, None) if second_token is not None: @@ -645,6 +685,218 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 assert_never(root_token) +def parse(grammar: Grammar, string: str) -> Node: + # Chunk up the str + string_tokens: list[str] = [] + brace_count = 0 + symbol = "" + for tok in string: + match tok: + case " ": + continue + case "(": + brace_count += 1 + if len(symbol) == 0: + raise ParseError( + f"Opening bracket '(' must be preceeded by symbol" + f" but was not.\n{string}" + ) + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ")": + brace_count -= 1 + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ",": + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case _: + symbol += tok + + if brace_count != 0: + raise ParseError( + f"Imbalanced braces, got {abs(brace_count)} too many" + f" {'(' if brace_count > 0 else ')'}." + ) + + if len(symbol) > 0: + string_tokens.append(symbol) + + # Convert to concrete tokens + tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = [] + for symbol in string_tokens: + if symbol in "(),": + tokens.append(symbol) # type: ignore + continue + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(): + tokens.append((symbol, grammar._leafs[symbol])) + case Grammar.NonTerminal(): + tokens.append((symbol, rule)) + case None: + raise ParseError( + f"Invalid symbol '{symbol}', must be either '(', ')', ',' or" + f" a symbol in {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + # If we're being strict that shared elements must be the same, then + # we can do so more cheaply at the beginning by just comparing subtokens + # before we parse. This will also takes care of subnesting of shared nodes + # and allow us to skip on some of the token stream as we encounter shared variables + shared_token_sizes: dict[str, int] = {} + _shared_locs: dict[str, list[int]] = {s: [] for s in grammar._shared} + + # We figure out the substrings of where each shared symbol begings and ends + if _shared_locs: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, tok in enumerate(tokens): + match tok: + case ( + "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) + ): + continue + case ")": + start = bracket_stack.pop(-1) + bracket_pairs[start] = i + case "(": + bracket_stack.append(i) + case (symbol, Grammar.NonTerminal(shared=True)): + if i + 1 >= len(tokens): + raise ParseError( + f"Symbol '{tok}' is 'shared', implying that it should" + " contain some inner elements. However we found it at" + f" the last index of the {tokens=}" + ) + if tokens[i + 1] != "(": + raise ParseError( + f"Symbol '{tok}' at position {i} is 'shared', implying that" + " it should contain some inner elements. 
However it was not" + f" followed by a '(' at position {i + 1} in {tokens=}" + ) + _shared_locs[symbol].append(i) + + # If we have more than one occurence of a shared symbol, + # we validate their subtokens match + for symbol, symbol_positions in _shared_locs.items(): + first_pos, rest = symbol_positions[0], symbol_positions[1:] + + # Calculate the inner tokens and length + bracket_first_start = first_pos + 1 + bracket_first_end = bracket_pairs[bracket_first_start] + + inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end] + shared_symbol_token_size = len(inner_tokens) + shared_token_sizes[symbol] = shared_symbol_token_size + + for symbol_start in rest: + # +2, skip symbol_start and skip opening bracket '(' + symbol_tokens = tokens[symbol_start + 2 : shared_symbol_token_size] + if symbol_tokens != inner_tokens: + raise ParseError( + f"Found mismatch in shared symbol '{symbol}'" + f" with {symbol=} starting at token `{symbol_start}`" + f" and the same symbol at token `{first_pos}` which has" + f" {inner_tokens=}.\n{tokens=}" + ) + + if len(tokens) == 0: + raise ParseError("Recieved an empty strng") + + match tokens[0]: + case (symbol, Leaf()): + if len(tokens) > 1: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `Terminal`, but was proceeded by more token." + f"\n{tokens=}" + ) + _, root = tokens[0] + case (symbol, Grammar.NonTerminal(op=op)): + if op is None: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NonTerminal` that is `passthrough`, i.e. it has no associated" + " operation and can not be the root." + ) + if len(tokens) < 4: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NoneTerminal`, but should have at least 3 more tokens" + " for a '(', 'child' and a closing ')'" + ) + + # NOTE: We don't care about shared here as we validate above that + # a shared variable can not contain itself, and there are no other + # symbols above or on the same level as this one (as it's the root). + # Hence we do not need to interact with `shared` here. + root = Container(symbol=symbol, children=[], op=op) + case "(" | ")" | ",": + raise ParseError("First token can not be a '(', ')' or a ','") + case rule: + assert_never(rule) + + if isinstance(root, Leaf): + return root + + variables: dict[str, Container | Passthrough] = {} + parent_stack: list[Container | Passthrough] = [] + current: Node = root + + token_stream = iter(tokens[1:]) + + for tok in token_stream: + match tok: + case ",": + parent_stack[-1].children.append(current) + case ")": + parent = parent_stack.pop() + parent.children.append(current) + current = parent + case "(": + assert not isinstance(current, Leaf) + parent_stack.append(current) + case (symbol, rule): + if isinstance(rule, Leaf): + current = rule + continue + + if rule.shared and (existing := variables.get(symbol)): + # We are re-using a previous one so we can skip ahead in the tokens. 
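                        # Reuse the subtree already built for this shared symbol; the
                        # validation pass above guarantees that the upcoming inner tokens
                        # are identical to those of the first occurrence, so they do not
                        # need to be parsed again.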
+ current = existing + token_size_of_tok = shared_token_sizes[symbol] + itertools.islice(token_stream, token_size_of_tok) # Skips + continue + + if rule.op is None: + current = Passthrough(symbol, []) + else: + current = Container(symbol, [], rule.op) + + if rule.shared: + variables[symbol] = current + case _: + assert_never(tok) + + return current + + # NOTE: Not sure we want this as a standalone function, but it serves to show some logic def is_valid( grammar: Grammar, @@ -660,10 +912,14 @@ def is_valid( ) # We should never encounter a situtation where we have some nesting of shared nodes, - # for example, consider the following, where L1 is shared. - # L1 -> x -> ... -> L1 -> x -> ... + # for example, consider the following, where L2 is shared. + # L2 -> x -> ... -> L1 -> x -> ... already_shared = already_shared or set() - if rule.shared and node.symbol in already_shared: + if ( + isinstance(rule, Grammar.NonTerminal) + and rule.shared + and node.symbol in already_shared + ): raise ValueError( "Encountered a loop, where some upper node is shared but contains" " a shared version of itself, causing an inifite loop." @@ -695,6 +951,7 @@ def is_valid( # TODO: Optimization, we don't need to recompute shared substrings. # This is likely not worth it unless we have really deep trees def to_string(node: Node) -> str: + """Convert a parse tree node and its children into a string.""" match node: case Leaf(symbol): return symbol @@ -704,39 +961,13 @@ def to_string(node: Node) -> str: assert_never(node) -def dfs_node(node: Node) -> Iterator[Node]: - stack: list[Node] = [node] - while stack: - nxt = stack.pop(-1) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - stack.extend(reversed(children)) - - -def bfs_node(node: Node) -> Iterator[Node]: - queue: list[Node] = [node] - while queue: - nxt = queue.pop(0) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - queue.extend(children) - - # TODO: The variables thing can mess up the max depth -def bfs_grammar( +def bfs_grammar( # noqa: C901, D103 grammar: Grammar, symbol: str, *, max_depth: int, - current_depth: int = 0, + current_depth: int = 1, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -749,16 +980,14 @@ def bfs_grammar( yield shared_node return # TODO: check - nxt_depth = current_depth + 1 + nxt_depth = current_depth + 2 rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - if rule.shared: - variables[symbol] = node yield node - case Grammar.NonTerminal(choices, op): + case Grammar.NonTerminal(choices=choices, op=op): for choice in choices: children = choice.split(" ") child_expansions: list[Iterator] = [ @@ -796,6 +1025,8 @@ def bfs_grammar( def to_model(node: Node) -> Any: + """Convert a parse tree node and its children into some object it represents.""" + def _build(_n: Node) -> Iterator[Any]: match _n: case Leaf(_, op): @@ -808,9 +1039,6 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - # import rich - - # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) @@ -838,30 +1066,30 @@ def _build(_n: Node) -> Iterator[Any]: ) ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), - "O": ["3", "1", "id"], + "O": ["4", 
"1", "id"], "reluconvbn": partial( - ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ReLUConvBN, in_channels=4, out_channels=3, kernel_size=3, stride=1, padding=1 ), "id": Identity, - "3": partial( - nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + "4": partial( + nn.Conv3d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 ), - "1": partial( - nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + "2": partial( + nn.Conv3d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 ), } -# https://stackoverflow.com/a/29597209 +# https://stackoverflow.com/a/29597210 def hierarchy_pos( G: nx.DiGraph, root: int, - width: float = 1.0, - vert_gap: float = 0.2, - vert_loc: float = 0, - xcenter: float = 0.5, + width: float = 2.0, + vert_gap: float = 1.2, + vert_loc: float = 1, + xcenter: float = 1.5, ) -> dict[int, tuple[float, float]]: - """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. Licensed under Creative Commons Attribution-Share Alike. If the graph is a tree this will return the positions to plot this in a @@ -891,10 +1119,10 @@ def hierarchy_pos( def _hierarchy_pos( G, root, - width=1.0, - vert_gap=0.2, - vert_loc: float = 0, - xcenter=0.5, + width=2.0, + vert_gap=1.2, + vert_loc: float = 1, + xcenter=1.5, pos: dict[int, tuple[float, float]] | None = None, parent=None, ) -> dict[int, tuple[float, float]]: @@ -911,9 +1139,9 @@ def _hierarchy_pos( children = list(G.neighbors(root)) if not isinstance(G, nx.DiGraph) and parent is not None: children.remove(parent) - if len(children) != 0: + if len(children) != 1: dx = width / len(children) - nextx = xcenter - width / 2 - dx / 2 + nextx = xcenter - width / 3 - dx / 2 for child in children: nextx += dx pos = _hierarchy_pos( diff --git a/test_graph.py b/test_graph.py index 625383f23..4de247c60 100644 --- a/test_graph.py +++ b/test_graph.py @@ -12,6 +12,8 @@ Passthrough, parse, to_model, + to_node_from_graph, + to_nxgraph, to_string, ) @@ -135,6 +137,11 @@ def test_string_serialization_and_deserialization_correct( # Test building assert to_model(parsed) == built + # Test graph and back again + graph = to_nxgraph(parsed, include_passthroughs=True) + node_again = to_node_from_graph(graph, grammar) + assert parsed == node_again + @pytest.mark.parametrize( ("grammar", "string"), @@ -158,3 +165,117 @@ def test_string_serialization_and_deserialization_correct( def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: with pytest.raises(ParseError): parse(grammar, string) + + +def test_dfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.dfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # go down left depth first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + # go down right depth first + Container( + "s_right", + 
children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" + + +def test_bfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.bfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # Second level first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + # Then 3rd level + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" From 38460a13f3a8751a8a0a7dc987027480e5045bf4 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 00:47:24 +0100 Subject: [PATCH 21/50] fix weird numerics and new opt --- graph.py | 321 ++++++++----------------------------------------------- perf.py | 6 +- 2 files changed, 49 insertions(+), 278 deletions(-) diff --git a/graph.py b/graph.py index e440a2896..185433d1e 100644 --- a/graph.py +++ b/graph.py @@ -25,8 +25,8 @@ class ParseError(NePSError): @dataclass class BufferedRandIntStream: rng: np.random.Generator - buffer_size: int = 51 - _cur_ix: int = 2 + buffer_size: int = 50 + _cur_ix: int = 0 MAX_INT: ClassVar[int] = np.iinfo(np.int64).max _nums: list[int] = field(default_factory=list) @@ -37,11 +37,11 @@ def next(self, n: int) -> int: self.MAX_INT, size=self.buffer_size, dtype=np.int64 ).tolist() - self._cur_ix = 1 + self._cur_ix = 0 i = self._nums[self._cur_ix] % n - self._cur_ix += 2 + self._cur_ix += 1 return i @@ -224,33 +224,55 @@ def sample_grammar( if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + stack: list[Container | Passthrough] = [] match rule: case Grammar.Terminal(): return grammar._leafs[symbol] - case Grammar.NonTerminal(choices=choices, op=op): + case Grammar.NonTerminal(choices, op, shared): shared_node = variables.get(symbol) if shared_node is not None: return shared_node - i = rng.next(len(choices)) - choice = choices[i] - chosen_children = choice.split(" ") - children = [ - sample_grammar(child_symbol, grammar, rng=rng, variables=variables) - for child_symbol in chosen_children - ] - if op is None: - node = Passthrough(symbol, children=children) - else: - node = Container(symbol, op=op, children=children) - - if rule.shared: - variables[symbol] = node - - return node + i = rng.next(len(rule.choices)) + initial_sample = rule.choices[i] + children_symbols = initial_sample.split(" ") + root = Passthrough(symbol, []) if op is None else Container(symbol, [], op) + stack.append(root) case _: assert_never(rule) + while stack: + parent = stack.pop() + i = rng.next(len(choices)) + choice = choices[i] + children_symbols = choice.split(" ") + + for child_symbol in children_symbols: + rule = 
grammar.rules[child_symbol] + match rule: + case Grammar.Terminal(): + parent.children.append(grammar._leafs[child_symbol]) + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(child_symbol) + if shared_node is not None: + parent.children.append(shared_node) + continue + + sub_parent = ( + Passthrough(child_symbol, []) + if op is None + else Container(child_symbol, [], op) + ) + parent.children.append(sub_parent) + stack.append(sub_parent) + + if shared: + variables[child_symbol] = sub_parent + case _: + assert_never(rule) + + return root + def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: # Find the unique root (a node with no incoming edges) @@ -434,257 +456,6 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: - bracket_stack: list[int] = [] - bracket_pairs: dict[int, int] = {} - for i, c in enumerate(string): - match c: - case "(": - bracket_stack.append(i) - case ")": - if len(bracket_stack) == 1: - raise ParseError( - f"Encountered mismatched brackets at position {i}" - f" in string '{string}'" - ) - bracket_start = bracket_stack.pop(0) - bracket_pairs[bracket_start] = i - case _: - continue - - if len(bracket_stack) > 1: - raise ParseError( - "Encountered a mismatch in the number of brackets." - f"The bracket(s) at position {bracket_stack} were never closed" - f" in the string '{string}'" - ) - - variables: dict[str, Node] = {} - - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 - symbol = "" - i = frm - while i <= to: # Use a while loop as we may jump ahead in the loop - c = string[i] - match c: - # Ignore whiespace - case _ if c in (" \n\t"): - i += 2 - # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed - # out a symbol from the s(a). Should only occur after a ")" - case "," if symbol == "": - i += 2 - # If the last character of a substring ends in a comma, this - # is not a valid string. - case "," if i == to: - raise ParseError( - "Got a (sub)string terminating in a ','." - " The ',' indicates something should come after it." - f" {string[frm : to + 2]}" - ) - # Otherwise, it's a valid ',' with a symbol before it - case ",": - i += 2 - node_symbol = symbol - symbol = "" - - rule = grammar.rules.get(node_symbol) - if rule is None: - raise ParseError( - f"Symbol '{node_symbol}' not in grammar" - f" {grammar.rules.keys()}" - ) - - # We parse out the node, even if it's shared, as we need to ensure - # what we parse out would match whatever is in the shared variables. - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(): - raise ParseError( - f"`NonTerminal` '{node_symbol}' can not be followed" - " by a comma ',' as it contains children inside brackets" - " '()'" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." 
- ) - else: - variables[node_symbol] = node - - yield node - # If we encounter an open bracket with no preceeding token, - # then this is invalid - case "(" if symbol == "": - raise ParseError( - "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 2]} " - ) - # Open a new subtree - case "(": - # Find out where we need to parse to get the children - bracket_start = i - bracket_end = bracket_pairs[bracket_start] - children = list(_parse(frm=bracket_start + 2, to=bracket_end)) - - # Advance the tokenizer past the end of that bracket - i = bracket_end + 2 - - # Reset the symbol - node_symbol = symbol - symbol = "" - - # Build the node with it's children - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.NonTerminal(op=op): - if strict: - child_substring = " ".join( - [child.symbol for child in children] - ) - if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 2] - raise ParseError( - f"While {substring=} is parsable, the children" - f" '{child_substring}' is not one of the valid" - f" choices for '{node_symbol} : {rule.choices}." - " To allow this anyways, pass `strict=False` to" - " this call." - ) - - if op is None: - node = Passthrough(node_symbol, children) - else: - node = Container(node_symbol, children, op) - case Grammar.Terminal(op=op): - raise ParseError("Encountered a '(' after a Terminal.") - case None: - raise ParseError( - f"No associated rule with {node_symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case ")" if symbol == "": - # This occurs in repeated brackets and is fine - # > 's(s(a))' - i += 2 - continue - case ")": - # If we reached this bracket, just make sure the parsing algorithm - # is working correctly by checking we are indeed where we think - # we should be which is at `to` - i += 2 - - node_symbol = symbol - symbol = "" # This should be the end of the recursed call anywho - - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError("A ')' should never follow a `NonTerminal`") - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case _: - i += 2 - symbol += c # Append to current token - - # This occurs when we did not encounter any special characters - # like `,`, `(` or `)`. - # I'm pretty sure the only case this can happen is if we have something - # like the string `"b"`, i.e. 
just a `Leaf` - if symbol != "": - rule = grammar.rules.get(symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError( - "Did not expected to have `NonTerminal` without" - " special characters '(', ')' or ','" - ) - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - yield node - - itr = _parse(frm=1, to=len(string) - 1) - root_token = next(itr, None) - second_token = next(itr, None) - if second_token is not None: - raise ParseError( - "If getting the root as a `Leaf`, then we should have no proceeding tokens." - ) - - match root_token: - case Leaf() | Container(): - return root_token - case Passthrough(): - raise ParseError("Should not have recieved a `Passthrough` as the root token") - case None: - raise ParseError(f"No token was parsed, was the string empty? {string=}") - case _: - assert_never(root_token) - - def parse(grammar: Grammar, string: str) -> Node: # Chunk up the str string_tokens: list[str] = [] @@ -838,7 +609,7 @@ def parse(grammar: Grammar, string: str) -> Node: if len(tokens) < 4: raise ParseError( f"First token was symbol '{symbol}' which is" - f" a `NoneTerminal`, but should have at least 3 more tokens" + f" a `NonTerminal`, but should have at least 3 more tokens" " for a '(', 'child' and a closing ')'" ) @@ -967,7 +738,7 @@ def bfs_grammar( # noqa: C901, D103 symbol: str, *, max_depth: int, - current_depth: int = 1, + current_depth: int = 0, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -980,7 +751,7 @@ def bfs_grammar( # noqa: C901, D103 yield shared_node return # TODO: check - nxt_depth = current_depth + 2 + nxt_depth = current_depth + 1 rule = grammar.rules.get(symbol) match rule: diff --git a/perf.py b/perf.py index 901b22f9c..268248b7d 100644 --- a/perf.py +++ b/perf.py @@ -40,8 +40,6 @@ if __name__ == "__main__": import time - import rich - grammar = Grammar.from_dict(structure) rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -49,7 +47,7 @@ # model = to_model(sample) t0 = time.perf_counter() - samples = 10000 + samples = 10_000 for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -60,5 +58,7 @@ # model = to_model(sample) t1 = time.perf_counter() + import rich + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") rich.print(f"duration for {samples} samples: {t1 - t0}s ") From 5e54588af6e48c15027f6b469457baa2b9042bf2 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 15:55:29 +0100 Subject: [PATCH 22/50] select --- graph.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/graph.py b/graph.py index 185433d1e..ce6db7a9f 100644 --- a/graph.py +++ b/graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from collections import defaultdict from collections.abc import Callable, Iterator from dataclasses import dataclass, field from functools import partial @@ -320,6 +321,87 @@ def _recurse(node_id: int) -> Node: return _recurse(_root) +def select( + root: Node, + *, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +) -> Iterator[Node]: + match how: + case ("symbol", symbol): + for node in root.bfs(): + if node.symbol == symbol: + yield 
node + case ("depth", depth): + if isinstance(depth, int): + depth = range(depth, depth + 1) + + queue_depth: list[tuple[Node, int]] = [(root, 0)] + while queue_depth: + nxt, d = queue_depth.pop(0) + match nxt: + case Leaf(): + continue + case Passthrough(children=children) | Container(children=children): + if d in depth: + yield nxt + if d < depth.stop: + queue_depth.extend([(child, d + 1) for child in children]) + case _: + assert_never(nxt) + + case ("climb", climb): + if isinstance(climb, int): + climb = range(climb, climb + 1) + + # First, we iterate downwards, populating parent paths back + # up. As the id for a Leaf is shared across all similar leafs + # as well as the fact shared nodes will share the same node id, + # we could have multiple parents per child id. + parents: defaultdict[int, list[Node]] = defaultdict(list) + + # We remove duplicates using a dict and the shared ids, a list would + # end up with duplicates for every leaf. We use this later to begin + # the climb iteration + leafs: dict[int, Node] = {} + + queue_climb: list[Node] = [] + while queue_climb: + nxt = queue_climb.pop(0) + this_id = id(nxt) + match nxt: + case Leaf(): + leafs[this_id] = nxt + case Passthrough(children=children) | Container(children=children): + for child in children: + parents[id(child)].append(nxt) + queue_climb.extend(children) + case _: + assert_never(nxt) + + # Now we work backwards from the leafs for each of the possible parents + # for the node id, yielding if we're within the climb path. If we've gone + # pass the climb value, we can stop iterating there. + climb_stack: list[tuple[Node, int]] = [] + climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) + while climb_stack: + node, climb_value = climb_stack.pop(-1) + if climb_value in climb: + yield node + + if climb_value < climb.stop: + possible_node_parents = parents[id(node)] + climb_stack.extend( + [(p, climb_value + 1) for p in possible_node_parents] + ) + + case _: + assert_never(how) + + def mutate_leaf_parents( root: Node, grammar: Grammar, From 3e272b45036a58254fb8b943e7955f7cf4415759 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 18:29:09 +0100 Subject: [PATCH 23/50] Test selection --- graph.py | 62 +++++++++++--------- perf.py | 6 +- test_graph.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 193 insertions(+), 32 deletions(-) diff --git a/graph.py b/graph.py index ce6db7a9f..b552f4bd4 100644 --- a/graph.py +++ b/graph.py @@ -109,9 +109,8 @@ class Leaf(NamedTuple): symbol: str op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) class Container(NamedTuple): @@ -119,18 +118,16 @@ class Container(NamedTuple): children: list[Node] op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) class Passthrough(NamedTuple): symbol: str children: list[Node] - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) Node: TypeAlias = Container | Passthrough | Leaf @@ -332,7 +329,7 @@ def select( ) -> Iterator[Node]: match how: case ("symbol", symbol): - for node in root.bfs(): + for node in bfs_node(root): if node.symbol == symbol: yield node case ("depth", depth): @@ -342,14 +339,17 @@ def select( queue_depth: list[tuple[Node, int]] = [(root, 0)] while queue_depth: nxt, d = queue_depth.pop(0) + if d in depth: + yield 
nxt + + if d >= depth.stop: + continue + match nxt: case Leaf(): - continue + pass case Passthrough(children=children) | Container(children=children): - if d in depth: - yield nxt - if d < depth.stop: - queue_depth.extend([(child, d + 1) for child in children]) + queue_depth.extend([(child, d + 1) for child in children]) case _: assert_never(nxt) @@ -368,7 +368,7 @@ def select( # the climb iteration leafs: dict[int, Node] = {} - queue_climb: list[Node] = [] + queue_climb: list[Node] = [root] while queue_climb: nxt = queue_climb.pop(0) this_id = id(nxt) @@ -385,17 +385,27 @@ def select( # Now we work backwards from the leafs for each of the possible parents # for the node id, yielding if we're within the climb path. If we've gone # pass the climb value, we can stop iterating there. - climb_stack: list[tuple[Node, int]] = [] - climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) - while climb_stack: - node, climb_value = climb_stack.pop(-1) + climb_queue: list[tuple[Node, int]] = [] + climb_queue.extend([(leaf, 0) for leaf in leafs.values()]) + seen: set[int] = set() + while climb_queue: + node, climb_value = climb_queue.pop(0) + node_id = id(node) + if node_id in seen: + continue + if climb_value in climb: + seen.add(node_id) yield node if climb_value < climb.stop: possible_node_parents = parents[id(node)] - climb_stack.extend( - [(p, climb_value + 1) for p in possible_node_parents] + climb_queue.extend( + [ + (p, climb_value + 1) + for p in possible_node_parents + if id(p) not in seen + ] ) case _: @@ -911,12 +921,10 @@ def _build(_n: Node) -> Iterator[Any]: assert_never(node) -structure = { +grammar = { "S": ( - Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], - nn.Sequential, - ) + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), "O": ["4", "1", "id"], diff --git a/perf.py b/perf.py index 268248b7d..28f7716e2 100644 --- a/perf.py +++ b/perf.py @@ -10,6 +10,7 @@ ReLUConvBN, parse, sample_grammar, + to_model, to_nxgraph, to_string, ) @@ -36,7 +37,6 @@ ), } - if __name__ == "__main__": import time @@ -44,7 +44,7 @@ rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) graph = to_nxgraph(sample) - # model = to_model(sample) + model = to_model(sample) t0 = time.perf_counter() samples = 10_000 @@ -52,7 +52,7 @@ for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) string = to_string(sample) - parse(string=string, grammar=grammar) + node = parse(string=string, grammar=grammar) # graph = to_nxgraph(sample) # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) # model = to_model(sample) diff --git a/test_graph.py b/test_graph.py index 4de247c60..09759a51b 100644 --- a/test_graph.py +++ b/test_graph.py @@ -10,7 +10,10 @@ Node, ParseError, Passthrough, + bfs_node, + dfs_node, parse, + select, to_model, to_node_from_graph, to_nxgraph, @@ -184,7 +187,7 @@ def test_dfs_node_container() -> None: ], op=join, ) - outcome = list(node.dfs()) + outcome = list(dfs_node(node)) expected = [ # First Container( @@ -241,7 +244,7 @@ def test_bfs_node_container() -> None: ], op=join, ) - outcome = list(node.bfs()) + outcome = list(bfs_node(node)) expected = [ # First Container( @@ -279,3 +282,153 @@ def test_bfs_node_container() -> None: ] for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): assert e == o, f"Failed at index {i}" + + +def test_select_symbol() -> None: + root = Container( + "a", + children=[ + Container( + "b", + 
children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("symbol", "d"))) + assert selected == [ + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ] + + +def test_select_depth() -> None: + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("depth", 1))) + assert selected == root.children + + selected = list(select(root, how=("depth", range(1, 3)))) + expected = [ + # Depth 1 + *root.children, + # Depth 2 + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + ] + assert selected == expected + + +def test_select_climb() -> None: + # NOTE: The order is rather arbitrary and not much thought has been given to it. + # However the test still tests a particular order that was done by trial and + # error. Feel free to redo the order if this changes. + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("climb", 0))) + assert selected == [ + Leaf("l3", op=T("l3")), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + Leaf("l1", op=T("l1")), + ] + + selected = list(select(root, how=("climb", range(1, 3)))) + expected = [ + root, + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + ] + for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): + assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" From 198d116d0ce20c1f8eeba6b2e7ff070d7399df56 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 19:42:36 +0100 Subject: [PATCH 24/50] Rework mutations --- graph.py | 130 ++++++++++++++++++++------------------------ graph_playground.py | 47 ++++++++++++++++ 2 files changed, 107 insertions(+), 70 deletions(-) create mode 100644 graph_playground.py diff --git a/graph.py b/graph.py index b552f4bd4..be1346905 100644 --- a/graph.py +++ b/graph.py @@ -2,7 +2,7 @@ import itertools from collections import defaultdict -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field from functools import partial from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias @@ -412,87 +412,77 @@ def select( assert_never(how) -def mutate_leaf_parents( +def mutations( root: Node, grammar: Grammar, *, - rng: np.random.Generator, + which: 
Iterable[Node], + max_mutation_depth: int, + rng_shuffle: np.random.Generator | None = None, variables: dict[str, Node] | None = None, -) -> Node: - """Mutate a node, returning a different possibility for it.""" +) -> Iterator[Node]: + """Mutate nodes, returning all the different possibilities for them. + + Args: + root: The root from which to operate. + grammar: The grammar which holds the rules used for mutation. + which: What nodes to mutate, look at `select()`. + max_mutation_depth: The maximum depth allowed for bfs iteration + on the mutant nodes. + rng_shuffle: Whether to shuffle the return order. This takes place at the place + when considering the possibilities for a given node, and does not follow + the order of `NonTerminal.choices`. + variables: Any predefined values you'd like for different symbols. + + Returns: + A new tree per possible mutation + """ if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") - variables = variables or {} - parents: dict[int, Node] = {} - leaf_parents: list[Node] = [] + variables = variables or {} + mutation_ids = {id(n) for n in which} - def _fill(n: Node, *, parent: Node) -> None: - node_id = id(n) - parents[node_id] = parent - match n: - case Leaf(): - leaf_parents.append(parent) - case Passthrough(_, children) | Container(_, children): - for child in children: - _fill(child, parent=parent) - case _: - assert_never(n) - - for child in root.children: - _fill(child, parent=root) - - # Note, we can have duplicates here, that's fine, we want to weight those - # parents with many leafs more heavily... TODO: Maybe? - chosen_node: Node = rng.choice(leaf_parents) # type: ignore - chosen_node_id = id(chosen_node) - - match chosen_node: - case Passthrough() | Container(): - new_subnode = sample_grammar( - chosen_node.symbol, - grammar, - rng=rng, - # NOTE: subfunction will update variables dict - # with any instantiated `variables` if it doesn't - # exist already in the passed in `variables` - variables=variables, - ) - case Leaf(): - raise ValueError("don't pass leafs") - case _: - assert_never(chosen_node) - - def _build(n: Node): - # If we find the node to replace, replace it. - if id(n) == chosen_node_id: - return new_subnode - - # It may be the case that `sample_grammar` above populated - # `variables`, replacing one of the shared nodes with something - # new. In that case, we want to use the new sampled value wherever - # we encounter that symbol. 
- shared_node = variables.get(n.symbol) - if shared_node is not None: - return shared_node - - # Otherwise, we just rebuild as needed - match n: + def _inner(node: Node) -> Iterator[Node]: + match node: case Leaf(): - return n - case Container(symbol, children, op): - return Container(symbol, children=[_build(c) for c in children], op=op) - case Passthrough(symbol, children): - return Passthrough(symbol, children=[_build(c) for c in children]) - case _: - assert_never(n) + # We can't mutate leafs as they don't have possible choices to choose from + # by definition so we ignore it even if it's in the set of `mutation_ids` + yield node + case Passthrough(children=children) | Container(children=children): + rule = grammar.rules.get(node.symbol) + if not isinstance(rule, Grammar.NonTerminal): + raise ValueError( + "Expected a `NonTerminal` for symbol '{node.symbol}' from the" + f" grammar but got rule {rule}" + ) - return _build(root) + # If we've already determined the value of this shared symbol + if (existing := variables.get(node.symbol)) is not None: + yield existing + return + # If mutate, we return all possible bfs values from that node. + if id(node) in mutation_ids: + yield from bfs_grammar( + grammar, + node.symbol, + rng_shuffle=rng_shuffle, + max_depth=max_mutation_depth, + variables=variables, + ) + else: + children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] + for new_children in itertools.product(*children_itrs): + node = node._replace(children=list) + new_node = node._replace(children=new_children) + if rule.shared: + variables[new_node.symbol] = node + yield new_node + case _: + assert_never(node) -def mutate_many( - node: Node, grammar: Grammar, *, rng: np.random.Generator -) -> Iterator[Node]: ... + yield from _inner(root) def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: diff --git a/graph_playground.py b/graph_playground.py new file mode 100644 index 000000000..5b995feef --- /dev/null +++ b/graph_playground.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from graph import Grammar, mutations, parse, select, to_string + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p a", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +root = parse(grammar_1, "s(p(s(a), a))") + +selections = list(select(root, how=("climb", range(1, 3)))) +mutants = mutations( + root=root, + grammar=grammar_1, + which=selections, + max_mutation_depth=3, +) +mutants = list(mutants) + +import rich + +rich.print("grammar", grammar_1) +rich.print("root", f"{to_string(root)}") +rich.print("selections", [to_string(s) for s in selections]) +rich.print("mutants", [to_string(m) for m in mutants]) From 383e5ddfec5741c4b3566cd8c20b18c6ecfac9f2 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:44:19 +0100 Subject: [PATCH 25/50] Fix parsing --- graph.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/graph.py b/graph.py index be1346905..5cf69b093 100644 --- a/graph.py +++ b/graph.py @@ -621,29 +621,31 @@ def parse(grammar: Grammar, string: str) -> Node: bracket_pairs: dict[int, int] = {} for i, tok in enumerate(tokens): match tok: - case ( - "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) - ): + case "," | (_, Leaf()): continue case 
")": start = bracket_stack.pop(-1) bracket_pairs[start] = i case "(": bracket_stack.append(i) - case (symbol, Grammar.NonTerminal(shared=True)): + case (symbol, Grammar.NonTerminal(shared=shared)): if i + 1 >= len(tokens): raise ParseError( - f"Symbol '{tok}' is 'shared', implying that it should" + f"Symbol '{tok}' is a `NonTerminal`, implying that it should" " contain some inner elements. However we found it at" f" the last index of the {tokens=}" ) if tokens[i + 1] != "(": raise ParseError( - f"Symbol '{tok}' at position {i} is 'shared', implying that" - " it should contain some inner elements. However it was not" - f" followed by a '(' at position {i + 1} in {tokens=}" + f"Symbol '{tok}' at position {i} is a `NonTerminal`," + " implying that it should contain some inner elements." + f" However it was not followed by a '(' at position {i + 1}" + f" in {tokens=}" ) - _shared_locs[symbol].append(i) + if shared is True: + _shared_locs[symbol].append(i) + case _: + assert_never(tok) # If we have more than one occurence of a shared symbol, # we validate their subtokens match From eb46f96554386481f56311d2ae8883b4e08c4beb Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:46:04 +0100 Subject: [PATCH 26/50] Fix mutation --- graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graph.py b/graph.py index 5cf69b093..8053dfb79 100644 --- a/graph.py +++ b/graph.py @@ -474,10 +474,9 @@ def _inner(node: Node) -> Iterator[Node]: else: children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] for new_children in itertools.product(*children_itrs): - node = node._replace(children=list) new_node = node._replace(children=new_children) if rule.shared: - variables[new_node.symbol] = node + variables[new_node.symbol] = new_node yield new_node case _: assert_never(node) From 39c62e9bafe2007d602bb2f3cead7f0baef57349 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 15:23:35 +0100 Subject: [PATCH 27/50] fix: stop unpacking of nn.Sequential --- graph.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/graph.py b/graph.py index 8053dfb79..aac96ea6c 100644 --- a/graph.py +++ b/graph.py @@ -8,7 +8,6 @@ from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never -import more_itertools import networkx as nx import numpy as np from torch import nn @@ -47,22 +46,21 @@ def next(self, n: int) -> int: class ReLUConvBN(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + def __init__(self, out_channels, kernel_size, stride, padding): super().__init__() self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv3d( - in_channels, - out_channels, - kernel_size, + nn.LazyConv2d( + out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, dilation=2, bias=False, ), - nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), + nn.LazyBatchNorm2d(affine=True, track_running_stats=True), ) def forward(self, x): @@ -73,8 +71,8 @@ class Identity(nn.Module): def __init__(self): super().__init__() - def forward(self): - return self + def forward(self, x): + return x def dfs_node(node: Node) -> Iterator[Node]: @@ -881,30 +879,33 @@ def bfs_grammar( # noqa: C901, D103 def to_model(node: Node) -> Any: """Convert a parse tree node and its children into some object it represents.""" - def _build(_n: Node) -> Iterator[Any]: + def _build(_n: Node) -> list[Any] 
| Any: match _n: case Leaf(_, op): - yield op() + return op() case Container(_, children, op): # The problem is that each child could be either: # * A single 'thing', in the case of Leaf or Container # * Multiple things, in case it's a passthrough # Hence we flatten them out into a single big children itr - flat_children = more_itertools.collapse( - _build(child) for child in children - ) - yield op(*flat_children) + _l = [] + for child in children: + _b = _build(child) + if isinstance(_b, list): + _l.extend(_b) + continue + _l.append(_b) + + return op(*_l) case Passthrough(_, children): - yield from (_build(child) for child in children) + return [_build(child) for child in children] case _: assert_never(node) match node: case Leaf() | Container(): - itr = _build(node) - obj = next(itr, None) - assert obj is not None, "Should have recieved at least one object" - assert next(itr, None) is None, "Should not have recieved two objects" + obj = _build(node) + assert not isinstance(obj, list) return obj case Passthrough(symbol): raise ValueError(f"Can not call build on a `Passthrough` {symbol}") From d1dc291f80b527498b2f356384d84178e6c8237e Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 16:08:12 +0100 Subject: [PATCH 28/50] chore: move tests --- test_graph.py => tests/test_graph.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test_graph.py => tests/test_graph.py (100%) diff --git a/test_graph.py b/tests/test_graph.py similarity index 100% rename from test_graph.py rename to tests/test_graph.py From ea8281e6f3b26fa313ce076ca559ce67c1fffbe1 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 7 Feb 2025 18:19:28 +0100 Subject: [PATCH 29/50] yo --- graph.py | 931 ++++++++++++++++++++++++++++++++++++++++++++++++++ perf.py | 32 ++ test_graph.py | 160 +++++++++ 3 files changed, 1123 insertions(+) create mode 100644 graph.py create mode 100644 perf.py create mode 100644 test_graph.py diff --git a/graph.py b/graph.py new file mode 100644 index 000000000..e8bc8d2a2 --- /dev/null +++ b/graph.py @@ -0,0 +1,931 @@ +from __future__ import annotations + +import itertools +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing_extensions import assert_never + +import more_itertools +import networkx as nx +from torch import nn + +from neps.exceptions import NePSError + +if TYPE_CHECKING: + import numpy as np + + +class ParseError(NePSError): + pass + + +class ReLUConvBN(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + super().__init__() + + self.kernel_size = kernel_size + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=1, + bias=False, + ), + nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + ) + + def forward(self, x): + return self.op(x) + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self): + return self + + +class Leaf(NamedTuple): + symbol: str + op: Callable + + +class Container(NamedTuple): + symbol: str + children: list[Node] + op: Callable + + +class Passthrough(NamedTuple): + symbol: str + children: list[Node] + + +Node: TypeAlias = Container | Passthrough | Leaf + + +@dataclass +class Tree: + root: Container | Leaf + + nodes: dict[int, Node] + + children_ids_of: dict[int, list[int]] + parent_id_of: dict[int, int] + 
leafs: list[int] + + @classmethod + def from_node(cls, node: Node) -> Tree: + """Create a `Tree` from a node, where node is considered the root.""" + nodes: dict[int, Node] = {} + children_ids_of: dict[int, list[int]] = {} + parent_id_of: dict[int, int] = {} + + def _traverse(n: Node, parent_id: int | None = None) -> None: + node_id = id(n) + nodes[node_id] = n + + if parent_id is not None: + parent_id_of[node_id] = parent_id + children_ids_of[parent_id].append(node_id) + + match n: + case Leaf(): + pass + case Container(_, children, _) | Passthrough(_, children): + children_ids_of[node_id] = [] + for child in children: + _traverse(child, node_id) + case _: + assert_never(n) + + _traverse(node) + + # Validate node is a Container or Leaf + if not isinstance(node, Container | Leaf): + raise ValueError("Root node must be a Container or Leaf") + + return cls( + root=node, + nodes=nodes, + children_ids_of=children_ids_of, + parent_id_of=parent_id_of, + leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], + ) + + +@dataclass +class Grammar: + rules: dict[str, Terminal | NonTerminal] + + class Terminal(NamedTuple): + op: Callable + shared: bool = False + + class NonTerminal(NamedTuple): + choices: list[str] + op: Callable | None = None + shared: bool = False + + @classmethod + def from_dict( + cls, + grammar: dict[ + str, + Callable + | list[str] + | tuple[list[str], Callable] + | Grammar.Terminal + | Grammar.NonTerminal, + ], + ) -> Grammar: + rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} + for symbol, rule in grammar.items(): + match rule: + case Grammar.Terminal() | Grammar.NonTerminal(): + rules[symbol] = rule + case (choices, op) if isinstance(choices, list) and callable(op): + # > e.g. "S": (["A", "A B", "C"], op) + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, op, shared=False) + + case choices if isinstance(choices, list): + # > e.g. "S": ["A", "A B", "C"] + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + + case op if callable(op): + # > e.g. "S": op + rules[symbol] = Grammar.Terminal(op, shared=False) + case _: + raise ValueError( + f"The rule for symbol {symbol} is not recognized. Should be" + " a list of of symbols, a callable or a tuple with both." 
+ f"\n Got {rule}" + ) + + return Grammar(rules) + + +def sample_grammar( + symbol: str, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + variables = variables or {} + rule = grammar.rules.get(symbol) + if rule is None: + raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices, op): + chosen_children = rng.choice(choices).split(" ") + children = [ + sample_grammar(child_symbol, grammar, rng=rng, variables=variables) + for child_symbol in chosen_children + ] + if op is None: + node = Passthrough(symbol, children=children) + else: + node = Container(symbol, op=op, children=children) + case _: + assert_never(rule) + + if rule.shared: + variables[symbol] = node + + return node + + +def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: + # Find the unique root (a node with no incoming edges) + _root = next((n for n, d in graph.in_degree if d == 0), None) + if _root is None: + raise ValueError( + "Could not find a root in the given graph (a node with indegree 0)." + ) + + variables: dict[str, Node] = {} + + def _recurse(node_id: int) -> Node: + symbol = graph.nodes[node_id].get("label") + if symbol is None: + raise ValueError(f"Node {node_id} does not have a 'label' property.") + + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + rule = grammar.rules.get(symbol) + if rule is None: + raise ValueError( + f"Symbol '{symbol}' not found in grammar rules: {grammar.rules.keys()}" + ) + + # Based on the type of rule, construct the proper node + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(choices=_, op=op): + children = [_recurse(child_id) for child_id in graph.successors(node_id)] + if op is None: + node = Passthrough(symbol, children) + else: + node = Container(symbol, children, op) + case _: + raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") + + if rule.shared: + variables[symbol] = node + + return node + + # Start with the root node + return _recurse(_root) + + +def mutate_leaf_parents( + root: Node, + grammar: Grammar, + *, + rng: np.random.Generator, + variables: dict[str, Node] | None = None, +) -> Node: + """Mutate a node, returning a different possibility for it.""" + if isinstance(root, Leaf): + raise ValueError(f"Can't mutate `Leaf`: {root}") + variables = variables or {} + tree: Tree = Tree.from_node(node=root) + + # Note, we can have duplicates here, that's fine, we want to weight those + # parents with many leafs more heavily... TODO: Maybe? + parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] + + chosen_node_id: int = rng.choice(parents) + chosen_node: Node = tree.nodes[chosen_node_id] + + match chosen_node: + case Passthrough() | Container(): + new_subnode = sample_grammar( + chosen_node.symbol, + grammar, + rng=rng, + # NOTE: subfunction will update variables dict + # with any instantiated `variables` if it doesn't + # exist already in the passed in `variables` + variables=variables, + ) + case Leaf(): + raise ValueError("don't pass leafs") + case _: + assert_never(chosen_node) + + def _build(n: Node): + # If we find the node to replace, replace it. 
+ if id(n) == chosen_node_id: + return new_subnode + + # It may be the case that `sample_grammar` above populated + # `variables`, replacing one of the shared nodes with something + # new. In that case, we want to use the new sampled value wherever + # we encounter that symbol. + shared_node = variables.get(n.symbol) + if shared_node is not None: + return shared_node + + # Otherwise, we just rebuild as needed + match n: + case Leaf(): + return n + case Container(symbol, children, op): + return Container(symbol, children=[_build(c) for c in children], op=op) + case Passthrough(symbol, children): + return Passthrough(symbol, children=[_build(c) for c in children]) + case _: + assert_never(n) + + return _build(root) + + +def mutate_many( + node: Node, grammar: Grammar, *, rng: np.random.Generator +) -> Iterator[Node]: ... + + +# TODO: This has issues as we are using id's, while we may have heirarchical components +# which share the same id. +def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: + nodes: list[tuple[int, dict]] = [] + edges: list[tuple[int, int]] = [] + id_generator: Iterator[int] = itertools.count() + + def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: + node_id = next(id_generator) + match node: + # Atoms are just a node with an edge to its parent + case Leaf(symbol): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + # If we have a passthrough and shouldn't include them, we simply + # forward on the `parent_id` we recieved to the children + case Passthrough(_, children) if include_passthroughs is False: + for child in children: + _recurse_fill_lists(child, parent_id=parent_id) + + # Containers are a node in the graph, with edges to its + # children (direct, or through passthrough) + case Container(symbol, children, _) | Passthrough(symbol, children): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + for child in children: + _recurse_fill_lists(child, parent_id=node_id) + + case _: + assert_never(root.kind) + + graph = nx.DiGraph() + root_id = next(id_generator) + match root: + case Leaf(): + nodes.append((root_id, {"label": root.symbol})) + case Passthrough(_, children) if include_passthroughs is False: + raise ValueError( + f"Can't create a graph starting from a `Passthrough` {root.symbol}, " + " unless `include_passthrough`" + ) + case Container(_, children, _) | Passthrough(_, children): + for child in children: + _recurse_fill_lists(child, parent_id=root_id) + case _: + assert_never(root) + + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, c in enumerate(string): + match c: + case "(": + bracket_stack.append(i) + case ")": + if len(bracket_stack) == 0: + raise ParseError( + f"Encountered mismatched brackets at position {i}" + f" in string '{string}'" + ) + bracket_start = bracket_stack.pop(-1) + bracket_pairs[bracket_start] = i + case _: + continue + + if len(bracket_stack) > 0: + raise ParseError( + "Encountered a mismatch in the number of brackets." 
+ f"The bracket(s) at position {bracket_stack} were never closed" + f" in the string '{string}'" + ) + + variables: dict[str, Node] = {} + + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + symbol = "" + i = frm + while i <= to: # Use a while loop as we may jump ahead in the loop + c = string[i] + match c: + # Ignore whiespace + case " " | "\n" | "\t": + i += 1 + # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed + # out a symbol from the s(a). Should only occur after a ")" + case "," if symbol == "": + assert string[i - 1] == ")" + i += 1 + # If the last character of a substring ends in a comma, this + # is not a valid string. + case "," if i == to: + raise ParseError( + "Got a (sub)string terminating in a ','." + " The ',' indicates something should come after it." + f" {string[frm : to + 1]}" + ) + # Otherwise, it's a valid ',' with a symbol before it + case ",": + i += 1 + node_symbol = symbol + symbol = "" + + rule = grammar.rules.get(node_symbol) + if rule is None: + raise ParseError( + f"Symbol '{node_symbol}' not in grammar" + f" {grammar.rules.keys()}" + ) + + # We parse out the node, even if it's shared, as we need to ensure + # what we parse out would match whatever is in the shared variables. + match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(): + raise ParseError( + f"`NonTerminal` '{node_symbol}' can not be followed" + " by a comma ',' as it contains children inside brackets" + " '()'" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + # If we encounter an open bracket with no preceeding token, + # then this is invalid + case "(" if symbol == "": + raise ParseError( + "Encountered an open brace '(' without any" + f" symbol parsed before it in string {string[frm : to + 1]} " + ) + # Open a new subtree + case "(": + assert i in bracket_pairs + + # Find out where we need to parse to get the children + bracket_start = i + bracket_end = bracket_pairs[bracket_start] + assert bracket_end <= to, f"{bracket_end=} > {to=}" + children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + + # Advance the tokenizer past the end of that bracket + i = bracket_end + 1 + + # Reset the symbol + node_symbol = symbol + symbol = "" + + # Build the node with it's children + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.NonTerminal(_, op): + if strict: + child_substring = " ".join( + child.symbol for child in children + ) + if child_substring not in rule.choices: + substring = string[bracket_start : bracket_end + 1] + raise ParseError( + f"While {substring=} is parsable, the children" + f" '{child_substring}' is not one of the valid" + f" choices for '{node_symbol} : {rule.choices}." + " To allow this anyways, pass `strict=False` to" + " this call." + ) + + if op is None: + node = Passthrough(node_symbol, children) + else: + node = Container(node_symbol, children, op) + case Grammar.Terminal(op): + raise ParseError("Encountered a '(' after a Terminal.") + case None: + raise ParseError( + f"No associated rule with {node_symbol=}. 
Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case ")" if symbol == "": + # This occurs in repeated brackets and is fine + # > 's(s(a))' + i += 1 + continue + case ")": + # If we reached this bracket, just make sure the parsing algorithm + # is working correctly by checking we are indeed where we think + # we should be which is at `to` + assert i == to + i += 1 + + node_symbol = symbol + symbol = "" # This should be the end of the recursed call anywho + + rule = grammar.rules.get(node_symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(node_symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError("A ')' should never follow a `NonTerminal`") + case None: + raise ParseError( + f"No associated rule with {symbol=}. Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + if rule.shared: + shared_node = variables.get(node_symbol) + if shared_node is not None: + if shared_node == node: + node = shared_node # Make sure return the shared instance + else: + other_substring = to_string(shared_node) + raise ParseError( + f"Encountered the substring {string[frm:to]}, where" + f" {node_symbol} is `shared=True`. However we have" + f" also found the substring {other_substring}." + ) + else: + variables[node_symbol] = node + + yield node + case _: + i += 1 + symbol += c # Append to current token + + # This occurs when we did not encounter any special characters + # like `,`, `(` or `)`. + # I'm pretty sure the only case this can happen is if we have something + # like the string `"b"`, i.e. just a `Leaf` + if symbol != "": + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(_, op): + raise ParseError( + "Did not expected to have `NonTerminal` without" + " special characters '(', ')' or ','" + ) + case None: + raise ParseError( + f"No associated rule with {symbol=}. Available" + f"tokens are {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + yield node + + itr = _parse(frm=0, to=len(string) - 1) + root_token = next(itr, None) + second_token = next(itr, None) + if second_token is not None: + raise ParseError( + "If getting the root as a `Leaf`, then we should have no proceeding tokens." + ) + + match root_token: + case Leaf() | Container(): + return root_token + case Passthrough(): + raise ParseError("Should not have recieved a `Passthrough` as the root token") + case None: + raise ParseError(f"No token was parsed, was the string empty? 
{string=}") + case _: + assert_never(root_token) + + +# NOTE: Not sure we want this as a standalone function, but it serves to show some logic +def is_valid( + grammar: Grammar, + node: Node, + *, + already_shared: set[str] | None = None, +) -> bool: + rule = grammar.rules.get(node.symbol) + if rule is None: + raise ValueError( + f"Node has unknown symbol {node.symbol}, valid symbols are" + f" {grammar.rules.keys()}" + ) + + # We should never encounter a situtation where we have some nesting of shared nodes, + # for example, consider the following, where L1 is shared. + # L1 -> x -> ... -> L1 -> x -> ... + already_shared = already_shared or set() + if rule.shared and node.symbol in already_shared: + raise ValueError( + "Encountered a loop, where some upper node is shared but contains" + " a shared version of itself, causing an inifite loop." + ) + + match node: + case Leaf(symbol): + return symbol in grammar.rules + case Container(symbol, children, _) | Passthrough(symbol, children): + s = " ".join(child.symbol for child in children) + + match rule: + case Grammar.Terminal(_): + return s in grammar.rules and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case Grammar.NonTerminal(choices, _): + return s in choices and all( + is_valid(grammar, child, already_shared=already_shared.copy()) + for child in children + ) + case _: + assert_never(rule) + case _: + assert_never(node) + + +# TODO: Optimization, we don't need to recompute shared substrings. +# This is likely not worth it unless we have really deep trees +def to_string(node: Node) -> str: + match node: + case Leaf(symbol): + return symbol + case Passthrough(symbol, children) | Container(symbol, children): + return f"{symbol}({', '.join(to_string(c) for c in children)})" + case _: + assert_never(node) + + +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + stack.extend(reversed(children)) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + yield nxt + queue.extend(children) + + +# TODO: The variables thing can mess up the max depth +def bfs_grammar( + grammar: Grammar, + symbol: str, + *, + max_depth: int, + current_depth: int = 0, + variables: dict[str, Node] | None = None, + rng_shuffle: np.random.Generator | None = None, +) -> Iterator[Node]: + if current_depth > max_depth: + return + + variables = variables or {} + shared_node = variables.get(symbol) + if shared_node is not None: + yield shared_node + return # TODO: check + + nxt_depth = current_depth + 1 + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(op): + node = Leaf(symbol, op) + if rule.shared: + variables[symbol] = node + yield node + case Grammar.NonTerminal(choices, op): + for choice in choices: + children = choice.split(" ") + child_expansions: list[Iterator] = [ + bfs_grammar( + grammar, + child_symbol, + max_depth=max_depth, + current_depth=nxt_depth, + rng_shuffle=rng_shuffle, + variables=variables, + ) + for child_symbol in children + ] + + if rng_shuffle: + # This works correctly with python lists, but typing for numpy is off + rng_shuffle.shuffle(child_expansions) # type: ignore + + for possible in itertools.product(*child_expansions): + if op is 
None: + node = Passthrough(symbol, children=list(possible)) + else: + node = Container(symbol, op=op, children=list(possible)) + + if rule.shared: + variables[symbol] = node + + yield node + case None: + raise ValueError( + f"Could not find symbol {symbol} in table with keys{grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + +def to_model(node: Node) -> Any: + def _build(_n: Node) -> Iterator[Any]: + match _n: + case Leaf(_, op): + yield op() + case Container(_, children, op): + # The problem is that each child could be either: + # * A single 'thing', in the case of Leaf or Container + # * Multiple things, in case it's a passthrough + # Hence we flatten them out into a single big children itr + flat_children = more_itertools.collapse( + _build(child) for child in children + ) + import rich + + rich.print(flat_children) + yield op(*flat_children) + case Passthrough(_, children): + yield from (_build(child) for child in children) + case _: + assert_never(node) + + match node: + case Leaf() | Container(): + itr = _build(node) + obj = next(itr, None) + assert obj is not None, "Should have recieved at least one object" + assert next(itr, None) is None, "Should not have recieved two objects" + return obj + case Passthrough(symbol): + raise ValueError(f"Can not call build on a `Passthrough` {symbol}") + case _: + assert_never(node) + + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +# https://stackoverflow.com/a/29597209 +def hierarchy_pos( + G: nx.DiGraph, + root: int, + width: float = 1.0, + vert_gap: float = 0.2, + vert_loc: float = 0, + xcenter: float = 0.5, +) -> dict[int, tuple[float, float]]: + """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + Licensed under Creative Commons Attribution-Share Alike. + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + """ + if not nx.is_tree(G): + raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") + + def _hierarchy_pos( + G, + root, + width=1.0, + vert_gap=0.2, + vert_loc: float = 0, + xcenter=0.5, + pos: dict[int, tuple[float, float]] | None = None, + parent=None, + ) -> dict[int, tuple[float, float]]: + """See hierarchy_pos docstring for most arguments. + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. 
- only affects it if non-directed + + """ + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children) != 0: + dx = width / len(children) + nextx = xcenter - width / 2 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos( + G, + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=nextx, + pos=pos, + parent=root, + ) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) diff --git a/perf.py b/perf.py new file mode 100644 index 000000000..63e37062e --- /dev/null +++ b/perf.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from functools import partial + +import numpy as np +from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from torch import nn + +structure = { + "S": ( + Grammar.NonTerminal( + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, + ) + ), + "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), + "O": ["3", "1", "id"], + "reluconvbn": partial( + ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "id": Identity, + "3": partial( + nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ), + "1": partial( + nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ), +} + + +if __name__ == "__main__": + sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) diff --git a/test_graph.py b/test_graph.py new file mode 100644 index 000000000..625383f23 --- /dev/null +++ b/test_graph.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import pytest +from graph import ( + Container, + Grammar, + Leaf, + Node, + ParseError, + Passthrough, + parse, + to_model, + to_string, +) + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +grammar_2 = Grammar.from_dict( + { + "L1": (["L2 L2 L3"], join), + "L2": Grammar.NonTerminal(["L3"], join, shared=True), + "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), + "a": T("a"), + "b": T("a"), + } +) + + +@pytest.mark.parametrize( + ("grammar", "string", "built", "node"), + [ + (grammar_1, "a", "a", Leaf("a", T("a"))), + (grammar_1, "b", "b", Leaf("b", T("b"))), + ( + grammar_1, + "s(a)", + "[a]", + Container("s", op=join, children=[Leaf("a", T("a"))]), + ), + ( + grammar_1, + "s(p(a, b))", + "[ab]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(a, b), p(s(a)))", + "[ab[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[Leaf("a", T("a")), Leaf("b", T("b"))], + ), + Passthrough( + "p", + children=[Container("s", children=[Leaf("a", T("a"))], op=join)], + ), + ], + op=join, + ), + ), + ( + grammar_1, + "s(p(s(a)))", + "[[a]]", + Container( + "s", + children=[ + Passthrough( + "p", + children=[ + Container( + "s", + children=[Leaf("a", T("a"))], + op=join, + ) + ], + ), + ], + op=join, + ), + ), + ], +) +def test_string_serialization_and_deserialization_correct( + grammar: Grammar, + string: str, + built: str, + node: Node, +) -> 
None: + # Test parsing + parsed = parse(grammar, string) + assert parsed == node + + # Test serialization + serialized_again = to_string(parsed) + assert serialized_again == string + + # Test building + assert to_model(parsed) == built + + +@pytest.mark.parametrize( + ("grammar", "string"), + [ + (grammar_1, "c"), + (grammar_1, ""), + (grammar_1, "s(a"), + (grammar_1, "p(a, b)"), + (grammar_1, "("), + (grammar_1, "s(a))"), + (grammar_1, "s((a)"), + (grammar_1, "s("), + (grammar_1, "s)"), + (grammar_1, "a, a"), + (grammar_1, "a,"), + (grammar_1, "s, s"), + # Invalid due to shared rule but not sharing values + (grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))"), + ], +) +def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: + with pytest.raises(ParseError): + parse(grammar, string) From c8dd67aca51237c147b63f396bda79aed0fd8613 Mon Sep 17 00:00:00 2001 From: Timur Carstensen Date: Fri, 7 Feb 2025 18:56:20 +0100 Subject: [PATCH 30/50] chore: perf testing --- graph.py | 4 ++-- perf.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/graph.py b/graph.py index e8bc8d2a2..b366a9ae4 100644 --- a/graph.py +++ b/graph.py @@ -808,9 +808,9 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - import rich + # import rich - rich.print(flat_children) + # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) diff --git a/perf.py b/perf.py index 63e37062e..901b22f9c 100644 --- a/perf.py +++ b/perf.py @@ -3,13 +3,22 @@ from functools import partial import numpy as np -from graph import Grammar, Identity, ReLUConvBN, sample_grammar +from graph import ( + Grammar, + Identity, + Node, + ReLUConvBN, + parse, + sample_grammar, + to_nxgraph, + to_string, +) from torch import nn structure = { "S": ( Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], + ["C", "reluconvbn", "S", "S C", "O O O", "S S O O O O O O"], nn.Sequential, ) ), @@ -29,4 +38,27 @@ if __name__ == "__main__": - sample = sample_grammar("S", grammar=grammar, rng=np.random.default_rng()) + import time + + import rich + + grammar = Grammar.from_dict(structure) + rng = np.random.default_rng() + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + graph = to_nxgraph(sample) + # model = to_model(sample) + + t0 = time.perf_counter() + samples = 10000 + + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + string = to_string(sample) + parse(string=string, grammar=grammar) + # graph = to_nxgraph(sample) + # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) + # model = to_model(sample) + + t1 = time.perf_counter() + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") + rich.print(f"duration for {samples} samples: {t1 - t0}s ") From ba82313ce1a1631548cb05cd138cb499ea8baf28 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Sun, 9 Feb 2025 23:43:09 +0100 Subject: [PATCH 31/50] optimizations on parsing and test --- graph.py | 596 ++++++++++++++++++++++++++++++++++---------------- test_graph.py | 121 ++++++++++ 2 files changed, 533 insertions(+), 184 deletions(-) diff --git a/graph.py b/graph.py index b366a9ae4..e440a2896 100644 --- a/graph.py +++ b/graph.py @@ -2,25 +2,49 @@ import itertools from collections.abc import Callable, Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import 
partial -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias +from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never import more_itertools import networkx as nx +import numpy as np from torch import nn from neps.exceptions import NePSError -if TYPE_CHECKING: - import numpy as np - class ParseError(NePSError): pass +# OPTIM: Calling `np.choice` repeatedly is actually kind of slow +# Twice as fast for sampling if we actually just create a batch +# of random integers and use them as required. +@dataclass +class BufferedRandIntStream: + rng: np.random.Generator + buffer_size: int = 51 + _cur_ix: int = 2 + + MAX_INT: ClassVar[int] = np.iinfo(np.int64).max + _nums: list[int] = field(default_factory=list) + + def next(self, n: int) -> int: + if self._cur_ix >= len(self._nums): + self._nums = self.rng.integers( + self.MAX_INT, size=self.buffer_size, dtype=np.int64 + ).tolist() + + self._cur_ix = 1 + + i = self._nums[self._cur_ix] % n + + self._cur_ix += 2 + return i + + class ReLUConvBN(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding): super().__init__() @@ -28,16 +52,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding): self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv2d( + nn.Conv3d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=1, + dilation=2, bias=False, ), - nn.BatchNorm2d(out_channels, affine=True, track_running_stats=True), + nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), ) def forward(self, x): @@ -52,88 +76,91 @@ def forward(self): return self +def dfs_node(node: Node) -> Iterator[Node]: + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + stack.extend(reversed(children)) + case _: + assert_never(nxt) + + +def bfs_node(node: Node) -> Iterator[Node]: + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + queue.extend(children) + case _: + assert_never(nxt) + + class Leaf(NamedTuple): symbol: str op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Container(NamedTuple): symbol: str children: list[Node] op: Callable + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node + class Passthrough(NamedTuple): symbol: str children: list[Node] + # Attach methods to nodes + dfs = dfs_node + bfs = bfs_node -Node: TypeAlias = Container | Passthrough | Leaf - - -@dataclass -class Tree: - root: Container | Leaf - - nodes: dict[int, Node] - - children_ids_of: dict[int, list[int]] - parent_id_of: dict[int, int] - leafs: list[int] - @classmethod - def from_node(cls, node: Node) -> Tree: - """Create a `Tree` from a node, where node is considered the root.""" - nodes: dict[int, Node] = {} - children_ids_of: dict[int, list[int]] = {} - parent_id_of: dict[int, int] = {} - - def _traverse(n: Node, parent_id: int | None = None) -> None: - node_id = id(n) - nodes[node_id] = n - - if parent_id is not None: - parent_id_of[node_id] = parent_id - children_ids_of[parent_id].append(node_id) - - match n: - case Leaf(): - pass - case Container(_, children, _) | Passthrough(_, children): - children_ids_of[node_id] = [] - for child in children: - _traverse(child, node_id) - case _: - assert_never(n) - - _traverse(node) - - # 
Validate node is a Container or Leaf - if not isinstance(node, Container | Leaf): - raise ValueError("Root node must be a Container or Leaf") - - return cls( - root=node, - nodes=nodes, - children_ids_of=children_ids_of, - parent_id_of=parent_id_of, - leafs=[nid for nid, n in nodes.items() if isinstance(n, Leaf)], - ) +Node: TypeAlias = Container | Passthrough | Leaf @dataclass class Grammar: rules: dict[str, Terminal | NonTerminal] + _shared: dict[str, NonTerminal] = field(init=False) + _leafs: dict[str, Leaf] = field(init=False) class Terminal(NamedTuple): op: Callable - shared: bool = False class NonTerminal(NamedTuple): choices: list[str] op: Callable | None = None shared: bool = False + def __post_init__(self) -> None: + self._shared = { + s: r + for s, r in self.rules.items() + if isinstance(r, Grammar.NonTerminal) and r.shared + } + self._leafs = { + s: Leaf(s, r.op) + for s, r in self.rules.items() + if isinstance(r, Grammar.Terminal) + } + @classmethod def from_dict( cls, @@ -167,11 +194,11 @@ def from_dict( if any(missing): raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") - rules[symbol] = Grammar.NonTerminal(choices, None, shared=False) + rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False) case op if callable(op): # > e.g. "S": op - rules[symbol] = Grammar.Terminal(op, shared=False) + rules[symbol] = Grammar.Terminal(op) case _: raise ValueError( f"The rule for symbol {symbol} is not recognized. Should be" @@ -186,23 +213,28 @@ def sample_grammar( symbol: str, grammar: Grammar, *, - rng: np.random.Generator, + rng: np.random.Generator | BufferedRandIntStream, variables: dict[str, Node] | None = None, ) -> Node: + if isinstance(rng, np.random.Generator): + rng = BufferedRandIntStream(rng=rng) + variables = variables or {} rule = grammar.rules.get(symbol) if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - match rule: - case Grammar.Terminal(op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(choices, op): - chosen_children = rng.choice(choices).split(" ") + case Grammar.Terminal(): + return grammar._leafs[symbol] + case Grammar.NonTerminal(choices=choices, op=op): + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + i = rng.next(len(choices)) + choice = choices[i] + chosen_children = choice.split(" ") children = [ sample_grammar(child_symbol, grammar, rng=rng, variables=variables) for child_symbol in chosen_children @@ -211,13 +243,13 @@ def sample_grammar( node = Passthrough(symbol, children=children) else: node = Container(symbol, op=op, children=children) - case _: - assert_never(rule) - if rule.shared: - variables[symbol] = node + if rule.shared: + variables[symbol] = node - return node + return node + case _: + assert_never(rule) def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: @@ -225,7 +257,7 @@ def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: _root = next((n for n, d in graph.in_degree if d == 0), None) if _root is None: raise ValueError( - "Could not find a root in the given graph (a node with indegree 0)." + "Could not find a root in the given graph (a node with indegree 1)." 
) variables: dict[str, Node] = {} @@ -235,10 +267,6 @@ def _recurse(node_id: int) -> Node: if symbol is None: raise ValueError(f"Node {node_id} does not have a 'label' property.") - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - rule = grammar.rules.get(symbol) if rule is None: raise ValueError( @@ -249,18 +277,21 @@ def _recurse(node_id: int) -> Node: match rule: case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(choices=_, op=op): + case Grammar.NonTerminal(op=op): + if (shared_node := variables.get(symbol)) is not None: + return shared_node + children = [_recurse(child_id) for child_id in graph.successors(node_id)] - if op is None: - node = Passthrough(symbol, children) - else: - node = Container(symbol, children, op) + node = ( + Passthrough(symbol, children) + if op is None + else Container(symbol, children, op) + ) + if rule.shared: + variables[symbol] = node case _: raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") - if rule.shared: - variables[symbol] = node - return node # Start with the root node @@ -278,14 +309,29 @@ def mutate_leaf_parents( if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") variables = variables or {} - tree: Tree = Tree.from_node(node=root) + + parents: dict[int, Node] = {} + leaf_parents: list[Node] = [] + + def _fill(n: Node, *, parent: Node) -> None: + node_id = id(n) + parents[node_id] = parent + match n: + case Leaf(): + leaf_parents.append(parent) + case Passthrough(_, children) | Container(_, children): + for child in children: + _fill(child, parent=parent) + case _: + assert_never(n) + + for child in root.children: + _fill(child, parent=root) # Note, we can have duplicates here, that's fine, we want to weight those # parents with many leafs more heavily... TODO: Maybe? - parents: list[int] = [tree.parent_id_of[leaf] for leaf in tree.leafs] - - chosen_node_id: int = rng.choice(parents) - chosen_node: Node = tree.nodes[chosen_node_id] + chosen_node: Node = rng.choice(leaf_parents) # type: ignore + chosen_node_id = id(chosen_node) match chosen_node: case Passthrough() | Container(): @@ -335,8 +381,6 @@ def mutate_many( ) -> Iterator[Node]: ... -# TODO: This has issues as we are using id's, while we may have heirarchical components -# which share the same id. 
def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: nodes: list[tuple[int, dict]] = [] edges: list[tuple[int, int]] = [] @@ -370,9 +414,10 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: graph = nx.DiGraph() root_id = next(id_generator) + nodes.append((root_id, {"label": root.symbol})) match root: case Leaf(): - nodes.append((root_id, {"label": root.symbol})) + pass case Passthrough(_, children) if include_passthroughs is False: raise ValueError( f"Can't create a graph starting from a `Passthrough` {root.symbol}, " @@ -389,7 +434,7 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: +def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: bracket_stack: list[int] = [] bracket_pairs: dict[int, int] = {} for i, c in enumerate(string): @@ -397,17 +442,17 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: case "(": bracket_stack.append(i) case ")": - if len(bracket_stack) == 0: + if len(bracket_stack) == 1: raise ParseError( f"Encountered mismatched brackets at position {i}" f" in string '{string}'" ) - bracket_start = bracket_stack.pop(-1) + bracket_start = bracket_stack.pop(0) bracket_pairs[bracket_start] = i case _: continue - if len(bracket_stack) > 0: + if len(bracket_stack) > 1: raise ParseError( "Encountered a mismatch in the number of brackets." f"The bracket(s) at position {bracket_stack} were never closed" @@ -416,31 +461,30 @@ def parse(grammar: Grammar, string: str, *, strict: bool = True) -> Node: variables: dict[str, Node] = {} - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 + def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 symbol = "" i = frm while i <= to: # Use a while loop as we may jump ahead in the loop c = string[i] match c: # Ignore whiespace - case " " | "\n" | "\t": - i += 1 + case _ if c in (" \n\t"): + i += 2 # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed # out a symbol from the s(a). Should only occur after a ")" case "," if symbol == "": - assert string[i - 1] == ")" - i += 1 + i += 2 # If the last character of a substring ends in a comma, this # is not a valid string. case "," if i == to: raise ParseError( "Got a (sub)string terminating in a ','." " The ',' indicates something should come after it." - f" {string[frm : to + 1]}" + f" {string[frm : to + 2]}" ) # Otherwise, it's a valid ',' with a symbol before it case ",": - i += 1 + i += 2 node_symbol = symbol symbol = "" @@ -454,7 +498,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # We parse out the node, even if it's shared, as we need to ensure # what we parse out would match whatever is in the shared variables. 
match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) case Grammar.NonTerminal(): raise ParseError( @@ -486,20 +530,17 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case "(" if symbol == "": raise ParseError( "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 1]} " + f" symbol parsed before it in string {string[frm : to + 2]} " ) # Open a new subtree case "(": - assert i in bracket_pairs - # Find out where we need to parse to get the children bracket_start = i bracket_end = bracket_pairs[bracket_start] - assert bracket_end <= to, f"{bracket_end=} > {to=}" - children = list(_parse(frm=bracket_start + 1, to=bracket_end)) + children = list(_parse(frm=bracket_start + 2, to=bracket_end)) # Advance the tokenizer past the end of that bracket - i = bracket_end + 1 + i = bracket_end + 2 # Reset the symbol node_symbol = symbol @@ -508,13 +549,13 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 # Build the node with it's children rule = grammar.rules.get(node_symbol) match rule: - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): if strict: child_substring = " ".join( - child.symbol for child in children + [child.symbol for child in children] ) if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 1] + substring = string[bracket_start : bracket_end + 2] raise ParseError( f"While {substring=} is parsable, the children" f" '{child_substring}' is not one of the valid" @@ -527,7 +568,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 node = Passthrough(node_symbol, children) else: node = Container(node_symbol, children, op) - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): raise ParseError("Encountered a '(' after a Terminal.") case None: raise ParseError( @@ -556,23 +597,22 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 case ")" if symbol == "": # This occurs in repeated brackets and is fine # > 's(s(a))' - i += 1 + i += 2 continue case ")": # If we reached this bracket, just make sure the parsing algorithm # is working correctly by checking we are indeed where we think # we should be which is at `to` - assert i == to - i += 1 + i += 2 node_symbol = symbol symbol = "" # This should be the end of the recursed call anywho rule = grammar.rules.get(node_symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(node_symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError("A ')' should never follow a `NonTerminal`") case None: raise ParseError( @@ -599,7 +639,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node case _: - i += 1 + i += 2 symbol += c # Append to current token # This occurs when we did not encounter any special characters @@ -609,9 +649,9 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 if symbol != "": rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - case Grammar.NonTerminal(_, op): + case Grammar.NonTerminal(op=op): raise ParseError( "Did not expected to have `NonTerminal` without" " special characters '(', ')' or ','" @@ -626,7 +666,7 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 yield node - itr = _parse(frm=0, 
to=len(string) - 1) + itr = _parse(frm=1, to=len(string) - 1) root_token = next(itr, None) second_token = next(itr, None) if second_token is not None: @@ -645,6 +685,218 @@ def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: C901, PLR0912, PLR0915 assert_never(root_token) +def parse(grammar: Grammar, string: str) -> Node: + # Chunk up the str + string_tokens: list[str] = [] + brace_count = 0 + symbol = "" + for tok in string: + match tok: + case " ": + continue + case "(": + brace_count += 1 + if len(symbol) == 0: + raise ParseError( + f"Opening bracket '(' must be preceeded by symbol" + f" but was not.\n{string}" + ) + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ")": + brace_count -= 1 + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ",": + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case _: + symbol += tok + + if brace_count != 0: + raise ParseError( + f"Imbalanced braces, got {abs(brace_count)} too many" + f" {'(' if brace_count > 0 else ')'}." + ) + + if len(symbol) > 0: + string_tokens.append(symbol) + + # Convert to concrete tokens + tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = [] + for symbol in string_tokens: + if symbol in "(),": + tokens.append(symbol) # type: ignore + continue + + rule = grammar.rules.get(symbol) + match rule: + case Grammar.Terminal(): + tokens.append((symbol, grammar._leafs[symbol])) + case Grammar.NonTerminal(): + tokens.append((symbol, rule)) + case None: + raise ParseError( + f"Invalid symbol '{symbol}', must be either '(', ')', ',' or" + f" a symbol in {grammar.rules.keys()}" + ) + case _: + assert_never(rule) + + # If we're being strict that shared elements must be the same, then + # we can do so more cheaply at the beginning by just comparing subtokens + # before we parse. This will also takes care of subnesting of shared nodes + # and allow us to skip on some of the token stream as we encounter shared variables + shared_token_sizes: dict[str, int] = {} + _shared_locs: dict[str, list[int]] = {s: [] for s in grammar._shared} + + # We figure out the substrings of where each shared symbol begings and ends + if _shared_locs: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, tok in enumerate(tokens): + match tok: + case ( + "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) + ): + continue + case ")": + start = bracket_stack.pop(-1) + bracket_pairs[start] = i + case "(": + bracket_stack.append(i) + case (symbol, Grammar.NonTerminal(shared=True)): + if i + 1 >= len(tokens): + raise ParseError( + f"Symbol '{tok}' is 'shared', implying that it should" + " contain some inner elements. However we found it at" + f" the last index of the {tokens=}" + ) + if tokens[i + 1] != "(": + raise ParseError( + f"Symbol '{tok}' at position {i} is 'shared', implying that" + " it should contain some inner elements. 
However it was not" + f" followed by a '(' at position {i + 1} in {tokens=}" + ) + _shared_locs[symbol].append(i) + + # If we have more than one occurence of a shared symbol, + # we validate their subtokens match + for symbol, symbol_positions in _shared_locs.items(): + first_pos, rest = symbol_positions[0], symbol_positions[1:] + + # Calculate the inner tokens and length + bracket_first_start = first_pos + 1 + bracket_first_end = bracket_pairs[bracket_first_start] + + inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end] + shared_symbol_token_size = len(inner_tokens) + shared_token_sizes[symbol] = shared_symbol_token_size + + for symbol_start in rest: + # +2, skip symbol_start and skip opening bracket '(' + symbol_tokens = tokens[symbol_start + 2 : shared_symbol_token_size] + if symbol_tokens != inner_tokens: + raise ParseError( + f"Found mismatch in shared symbol '{symbol}'" + f" with {symbol=} starting at token `{symbol_start}`" + f" and the same symbol at token `{first_pos}` which has" + f" {inner_tokens=}.\n{tokens=}" + ) + + if len(tokens) == 0: + raise ParseError("Recieved an empty strng") + + match tokens[0]: + case (symbol, Leaf()): + if len(tokens) > 1: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `Terminal`, but was proceeded by more token." + f"\n{tokens=}" + ) + _, root = tokens[0] + case (symbol, Grammar.NonTerminal(op=op)): + if op is None: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NonTerminal` that is `passthrough`, i.e. it has no associated" + " operation and can not be the root." + ) + if len(tokens) < 4: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NoneTerminal`, but should have at least 3 more tokens" + " for a '(', 'child' and a closing ')'" + ) + + # NOTE: We don't care about shared here as we validate above that + # a shared variable can not contain itself, and there are no other + # symbols above or on the same level as this one (as it's the root). + # Hence we do not need to interact with `shared` here. + root = Container(symbol=symbol, children=[], op=op) + case "(" | ")" | ",": + raise ParseError("First token can not be a '(', ')' or a ','") + case rule: + assert_never(rule) + + if isinstance(root, Leaf): + return root + + variables: dict[str, Container | Passthrough] = {} + parent_stack: list[Container | Passthrough] = [] + current: Node = root + + token_stream = iter(tokens[1:]) + + for tok in token_stream: + match tok: + case ",": + parent_stack[-1].children.append(current) + case ")": + parent = parent_stack.pop() + parent.children.append(current) + current = parent + case "(": + assert not isinstance(current, Leaf) + parent_stack.append(current) + case (symbol, rule): + if isinstance(rule, Leaf): + current = rule + continue + + if rule.shared and (existing := variables.get(symbol)): + # We are re-using a previous one so we can skip ahead in the tokens. 
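+ # `shared_token_sizes[symbol]` holds the number of inner tokens measured + # from the first occurrence of this shared symbol in the validation pass + # above; it is used as the span to skip over for this repeated occurrence.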
+ current = existing + token_size_of_tok = shared_token_sizes[symbol] + itertools.islice(token_stream, token_size_of_tok) # Skips + continue + + if rule.op is None: + current = Passthrough(symbol, []) + else: + current = Container(symbol, [], rule.op) + + if rule.shared: + variables[symbol] = current + case _: + assert_never(tok) + + return current + + # NOTE: Not sure we want this as a standalone function, but it serves to show some logic def is_valid( grammar: Grammar, @@ -660,10 +912,14 @@ def is_valid( ) # We should never encounter a situtation where we have some nesting of shared nodes, - # for example, consider the following, where L1 is shared. - # L1 -> x -> ... -> L1 -> x -> ... + # for example, consider the following, where L2 is shared. + # L2 -> x -> ... -> L1 -> x -> ... already_shared = already_shared or set() - if rule.shared and node.symbol in already_shared: + if ( + isinstance(rule, Grammar.NonTerminal) + and rule.shared + and node.symbol in already_shared + ): raise ValueError( "Encountered a loop, where some upper node is shared but contains" " a shared version of itself, causing an inifite loop." @@ -695,6 +951,7 @@ def is_valid( # TODO: Optimization, we don't need to recompute shared substrings. # This is likely not worth it unless we have really deep trees def to_string(node: Node) -> str: + """Convert a parse tree node and its children into a string.""" match node: case Leaf(symbol): return symbol @@ -704,39 +961,13 @@ def to_string(node: Node) -> str: assert_never(node) -def dfs_node(node: Node) -> Iterator[Node]: - stack: list[Node] = [node] - while stack: - nxt = stack.pop(-1) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - stack.extend(reversed(children)) - - -def bfs_node(node: Node) -> Iterator[Node]: - queue: list[Node] = [node] - while queue: - nxt = queue.pop(0) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - yield nxt - queue.extend(children) - - # TODO: The variables thing can mess up the max depth -def bfs_grammar( +def bfs_grammar( # noqa: C901, D103 grammar: Grammar, symbol: str, *, max_depth: int, - current_depth: int = 0, + current_depth: int = 1, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -749,16 +980,14 @@ def bfs_grammar( yield shared_node return # TODO: check - nxt_depth = current_depth + 1 + nxt_depth = current_depth + 2 rule = grammar.rules.get(symbol) match rule: - case Grammar.Terminal(op): + case Grammar.Terminal(op=op): node = Leaf(symbol, op) - if rule.shared: - variables[symbol] = node yield node - case Grammar.NonTerminal(choices, op): + case Grammar.NonTerminal(choices=choices, op=op): for choice in choices: children = choice.split(" ") child_expansions: list[Iterator] = [ @@ -796,6 +1025,8 @@ def bfs_grammar( def to_model(node: Node) -> Any: + """Convert a parse tree node and its children into some object it represents.""" + def _build(_n: Node) -> Iterator[Any]: match _n: case Leaf(_, op): @@ -808,9 +1039,6 @@ def _build(_n: Node) -> Iterator[Any]: flat_children = more_itertools.collapse( _build(child) for child in children ) - # import rich - - # rich.print(flat_children) yield op(*flat_children) case Passthrough(_, children): yield from (_build(child) for child in children) @@ -838,30 +1066,30 @@ def _build(_n: Node) -> Iterator[Any]: ) ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), - "O": ["3", "1", "id"], + "O": ["4", 
"1", "id"], "reluconvbn": partial( - ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + ReLUConvBN, in_channels=4, out_channels=3, kernel_size=3, stride=1, padding=1 ), "id": Identity, - "3": partial( - nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 + "4": partial( + nn.Conv3d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 ), - "1": partial( - nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + "2": partial( + nn.Conv3d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 ), } -# https://stackoverflow.com/a/29597209 +# https://stackoverflow.com/a/29597210 def hierarchy_pos( G: nx.DiGraph, root: int, - width: float = 1.0, - vert_gap: float = 0.2, - vert_loc: float = 0, - xcenter: float = 0.5, + width: float = 2.0, + vert_gap: float = 1.2, + vert_loc: float = 1, + xcenter: float = 1.5, ) -> dict[int, tuple[float, float]]: - """From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. Licensed under Creative Commons Attribution-Share Alike. If the graph is a tree this will return the positions to plot this in a @@ -891,10 +1119,10 @@ def hierarchy_pos( def _hierarchy_pos( G, root, - width=1.0, - vert_gap=0.2, - vert_loc: float = 0, - xcenter=0.5, + width=2.0, + vert_gap=1.2, + vert_loc: float = 1, + xcenter=1.5, pos: dict[int, tuple[float, float]] | None = None, parent=None, ) -> dict[int, tuple[float, float]]: @@ -911,9 +1139,9 @@ def _hierarchy_pos( children = list(G.neighbors(root)) if not isinstance(G, nx.DiGraph) and parent is not None: children.remove(parent) - if len(children) != 0: + if len(children) != 1: dx = width / len(children) - nextx = xcenter - width / 2 - dx / 2 + nextx = xcenter - width / 3 - dx / 2 for child in children: nextx += dx pos = _hierarchy_pos( diff --git a/test_graph.py b/test_graph.py index 625383f23..4de247c60 100644 --- a/test_graph.py +++ b/test_graph.py @@ -12,6 +12,8 @@ Passthrough, parse, to_model, + to_node_from_graph, + to_nxgraph, to_string, ) @@ -135,6 +137,11 @@ def test_string_serialization_and_deserialization_correct( # Test building assert to_model(parsed) == built + # Test graph and back again + graph = to_nxgraph(parsed, include_passthroughs=True) + node_again = to_node_from_graph(graph, grammar) + assert parsed == node_again + @pytest.mark.parametrize( ("grammar", "string"), @@ -158,3 +165,117 @@ def test_string_serialization_and_deserialization_correct( def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: with pytest.raises(ParseError): parse(grammar, string) + + +def test_dfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.dfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # go down left depth first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + # go down right depth first + Container( + "s_right", + 
children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" + + +def test_bfs_node_container() -> None: + node = Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ) + outcome = list(node.bfs()) + expected = [ + # First + Container( + "s", + children=[ + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + ], + op=join, + ), + # Second level first + Container( + "s_left", + children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], + op=join, + ), + Container( + "s_right", + children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], + op=join, + ), + # Then 3rd level + Leaf("a_left", T("a")), + Leaf("b_left", T("b")), + Leaf("a_right", T("a")), + Leaf("b_right", T("b")), + ] + for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): + assert e == o, f"Failed at index {i}" From 9d04f3d78e8e0ef72ecd728d99eecddc94ce4ba3 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 00:47:24 +0100 Subject: [PATCH 32/50] fix weird numerics and new opt --- graph.py | 321 ++++++++----------------------------------------------- perf.py | 6 +- 2 files changed, 49 insertions(+), 278 deletions(-) diff --git a/graph.py b/graph.py index e440a2896..185433d1e 100644 --- a/graph.py +++ b/graph.py @@ -25,8 +25,8 @@ class ParseError(NePSError): @dataclass class BufferedRandIntStream: rng: np.random.Generator - buffer_size: int = 51 - _cur_ix: int = 2 + buffer_size: int = 50 + _cur_ix: int = 0 MAX_INT: ClassVar[int] = np.iinfo(np.int64).max _nums: list[int] = field(default_factory=list) @@ -37,11 +37,11 @@ def next(self, n: int) -> int: self.MAX_INT, size=self.buffer_size, dtype=np.int64 ).tolist() - self._cur_ix = 1 + self._cur_ix = 0 i = self._nums[self._cur_ix] % n - self._cur_ix += 2 + self._cur_ix += 1 return i @@ -224,33 +224,55 @@ def sample_grammar( if rule is None: raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") + stack: list[Container | Passthrough] = [] match rule: case Grammar.Terminal(): return grammar._leafs[symbol] - case Grammar.NonTerminal(choices=choices, op=op): + case Grammar.NonTerminal(choices, op, shared): shared_node = variables.get(symbol) if shared_node is not None: return shared_node - i = rng.next(len(choices)) - choice = choices[i] - chosen_children = choice.split(" ") - children = [ - sample_grammar(child_symbol, grammar, rng=rng, variables=variables) - for child_symbol in chosen_children - ] - if op is None: - node = Passthrough(symbol, children=children) - else: - node = Container(symbol, op=op, children=children) - - if rule.shared: - variables[symbol] = node - - return node + i = rng.next(len(rule.choices)) + initial_sample = rule.choices[i] + children_symbols = initial_sample.split(" ") + root = Passthrough(symbol, []) if op is None else Container(symbol, [], op) + stack.append(root) case _: assert_never(rule) + while stack: + parent = stack.pop() + i = rng.next(len(choices)) + choice = choices[i] + children_symbols = choice.split(" ") + + for child_symbol in children_symbols: + rule = 
grammar.rules[child_symbol] + match rule: + case Grammar.Terminal(): + parent.children.append(grammar._leafs[child_symbol]) + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(child_symbol) + if shared_node is not None: + parent.children.append(shared_node) + continue + + sub_parent = ( + Passthrough(child_symbol, []) + if op is None + else Container(child_symbol, [], op) + ) + parent.children.append(sub_parent) + stack.append(sub_parent) + + if shared: + variables[child_symbol] = sub_parent + case _: + assert_never(rule) + + return root + def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: # Find the unique root (a node with no incoming edges) @@ -434,257 +456,6 @@ def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: return graph -def parse_old(grammar: Grammar, string: str, *, strict: bool = True) -> Node: - bracket_stack: list[int] = [] - bracket_pairs: dict[int, int] = {} - for i, c in enumerate(string): - match c: - case "(": - bracket_stack.append(i) - case ")": - if len(bracket_stack) == 1: - raise ParseError( - f"Encountered mismatched brackets at position {i}" - f" in string '{string}'" - ) - bracket_start = bracket_stack.pop(0) - bracket_pairs[bracket_start] = i - case _: - continue - - if len(bracket_stack) > 1: - raise ParseError( - "Encountered a mismatch in the number of brackets." - f"The bracket(s) at position {bracket_stack} were never closed" - f" in the string '{string}'" - ) - - variables: dict[str, Node] = {} - - def _parse(frm: int, to: int) -> Iterator[Node]: # noqa: PLR0912, PLR0915 - symbol = "" - i = frm - while i <= to: # Use a while loop as we may jump ahead in the loop - c = string[i] - match c: - # Ignore whiespace - case _ if c in (" \n\t"): - i += 2 - # > Ignore, e.g. s(s(a), b) ... In this case, we already parsed - # out a symbol from the s(a). Should only occur after a ")" - case "," if symbol == "": - i += 2 - # If the last character of a substring ends in a comma, this - # is not a valid string. - case "," if i == to: - raise ParseError( - "Got a (sub)string terminating in a ','." - " The ',' indicates something should come after it." - f" {string[frm : to + 2]}" - ) - # Otherwise, it's a valid ',' with a symbol before it - case ",": - i += 2 - node_symbol = symbol - symbol = "" - - rule = grammar.rules.get(node_symbol) - if rule is None: - raise ParseError( - f"Symbol '{node_symbol}' not in grammar" - f" {grammar.rules.keys()}" - ) - - # We parse out the node, even if it's shared, as we need to ensure - # what we parse out would match whatever is in the shared variables. - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(): - raise ParseError( - f"`NonTerminal` '{node_symbol}' can not be followed" - " by a comma ',' as it contains children inside brackets" - " '()'" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." 
- ) - else: - variables[node_symbol] = node - - yield node - # If we encounter an open bracket with no preceeding token, - # then this is invalid - case "(" if symbol == "": - raise ParseError( - "Encountered an open brace '(' without any" - f" symbol parsed before it in string {string[frm : to + 2]} " - ) - # Open a new subtree - case "(": - # Find out where we need to parse to get the children - bracket_start = i - bracket_end = bracket_pairs[bracket_start] - children = list(_parse(frm=bracket_start + 2, to=bracket_end)) - - # Advance the tokenizer past the end of that bracket - i = bracket_end + 2 - - # Reset the symbol - node_symbol = symbol - symbol = "" - - # Build the node with it's children - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.NonTerminal(op=op): - if strict: - child_substring = " ".join( - [child.symbol for child in children] - ) - if child_substring not in rule.choices: - substring = string[bracket_start : bracket_end + 2] - raise ParseError( - f"While {substring=} is parsable, the children" - f" '{child_substring}' is not one of the valid" - f" choices for '{node_symbol} : {rule.choices}." - " To allow this anyways, pass `strict=False` to" - " this call." - ) - - if op is None: - node = Passthrough(node_symbol, children) - else: - node = Container(node_symbol, children, op) - case Grammar.Terminal(op=op): - raise ParseError("Encountered a '(' after a Terminal.") - case None: - raise ParseError( - f"No associated rule with {node_symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case ")" if symbol == "": - # This occurs in repeated brackets and is fine - # > 's(s(a))' - i += 2 - continue - case ")": - # If we reached this bracket, just make sure the parsing algorithm - # is working correctly by checking we are indeed where we think - # we should be which is at `to` - i += 2 - - node_symbol = symbol - symbol = "" # This should be the end of the recursed call anywho - - rule = grammar.rules.get(node_symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(node_symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError("A ')' should never follow a `NonTerminal`") - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - if rule.shared: - shared_node = variables.get(node_symbol) - if shared_node is not None: - if shared_node == node: - node = shared_node # Make sure return the shared instance - else: - other_substring = to_string(shared_node) - raise ParseError( - f"Encountered the substring {string[frm:to]}, where" - f" {node_symbol} is `shared=True`. However we have" - f" also found the substring {other_substring}." - ) - else: - variables[node_symbol] = node - - yield node - case _: - i += 2 - symbol += c # Append to current token - - # This occurs when we did not encounter any special characters - # like `,`, `(` or `)`. - # I'm pretty sure the only case this can happen is if we have something - # like the string `"b"`, i.e. 
just a `Leaf` - if symbol != "": - rule = grammar.rules.get(symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(op=op): - raise ParseError( - "Did not expected to have `NonTerminal` without" - " special characters '(', ')' or ','" - ) - case None: - raise ParseError( - f"No associated rule with {symbol=}. Available" - f"tokens are {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - yield node - - itr = _parse(frm=1, to=len(string) - 1) - root_token = next(itr, None) - second_token = next(itr, None) - if second_token is not None: - raise ParseError( - "If getting the root as a `Leaf`, then we should have no proceeding tokens." - ) - - match root_token: - case Leaf() | Container(): - return root_token - case Passthrough(): - raise ParseError("Should not have recieved a `Passthrough` as the root token") - case None: - raise ParseError(f"No token was parsed, was the string empty? {string=}") - case _: - assert_never(root_token) - - def parse(grammar: Grammar, string: str) -> Node: # Chunk up the str string_tokens: list[str] = [] @@ -838,7 +609,7 @@ def parse(grammar: Grammar, string: str) -> Node: if len(tokens) < 4: raise ParseError( f"First token was symbol '{symbol}' which is" - f" a `NoneTerminal`, but should have at least 3 more tokens" + f" a `NonTerminal`, but should have at least 3 more tokens" " for a '(', 'child' and a closing ')'" ) @@ -967,7 +738,7 @@ def bfs_grammar( # noqa: C901, D103 symbol: str, *, max_depth: int, - current_depth: int = 1, + current_depth: int = 0, variables: dict[str, Node] | None = None, rng_shuffle: np.random.Generator | None = None, ) -> Iterator[Node]: @@ -980,7 +751,7 @@ def bfs_grammar( # noqa: C901, D103 yield shared_node return # TODO: check - nxt_depth = current_depth + 2 + nxt_depth = current_depth + 1 rule = grammar.rules.get(symbol) match rule: diff --git a/perf.py b/perf.py index 901b22f9c..268248b7d 100644 --- a/perf.py +++ b/perf.py @@ -40,8 +40,6 @@ if __name__ == "__main__": import time - import rich - grammar = Grammar.from_dict(structure) rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -49,7 +47,7 @@ # model = to_model(sample) t0 = time.perf_counter() - samples = 10000 + samples = 10_000 for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) @@ -60,5 +58,7 @@ # model = to_model(sample) t1 = time.perf_counter() + import rich + rich.print(f"sampling takes {(t1 - t0) / samples}s on average over {samples} samples") rich.print(f"duration for {samples} samples: {t1 - t0}s ") From b4652ddb2ff993bd898d4076898b493c857cef35 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 15:55:29 +0100 Subject: [PATCH 33/50] select --- graph.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/graph.py b/graph.py index 185433d1e..ce6db7a9f 100644 --- a/graph.py +++ b/graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from collections import defaultdict from collections.abc import Callable, Iterator from dataclasses import dataclass, field from functools import partial @@ -320,6 +321,87 @@ def _recurse(node_id: int) -> Node: return _recurse(_root) +def select( + root: Node, + *, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +) -> Iterator[Node]: + match how: + case ("symbol", symbol): + for node in root.bfs(): + if node.symbol == symbol: + yield 
node + case ("depth", depth): + if isinstance(depth, int): + depth = range(depth, depth + 1) + + queue_depth: list[tuple[Node, int]] = [(root, 0)] + while queue_depth: + nxt, d = queue_depth.pop(0) + match nxt: + case Leaf(): + continue + case Passthrough(children=children) | Container(children=children): + if d in depth: + yield nxt + if d < depth.stop: + queue_depth.extend([(child, d + 1) for child in children]) + case _: + assert_never(nxt) + + case ("climb", climb): + if isinstance(climb, int): + climb = range(climb, climb + 1) + + # First, we iterate downwards, populating parent paths back + # up. As the id for a Leaf is shared across all similar leafs + # as well as the fact shared nodes will share the same node id, + # we could have multiple parents per child id. + parents: defaultdict[int, list[Node]] = defaultdict(list) + + # We remove duplicates using a dict and the shared ids, a list would + # end up with duplicates for every leaf. We use this later to begin + # the climb iteration + leafs: dict[int, Node] = {} + + queue_climb: list[Node] = [] + while queue_climb: + nxt = queue_climb.pop(0) + this_id = id(nxt) + match nxt: + case Leaf(): + leafs[this_id] = nxt + case Passthrough(children=children) | Container(children=children): + for child in children: + parents[id(child)].append(nxt) + queue_climb.extend(children) + case _: + assert_never(nxt) + + # Now we work backwards from the leafs for each of the possible parents + # for the node id, yielding if we're within the climb path. If we've gone + # pass the climb value, we can stop iterating there. + climb_stack: list[tuple[Node, int]] = [] + climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) + while climb_stack: + node, climb_value = climb_stack.pop(-1) + if climb_value in climb: + yield node + + if climb_value < climb.stop: + possible_node_parents = parents[id(node)] + climb_stack.extend( + [(p, climb_value + 1) for p in possible_node_parents] + ) + + case _: + assert_never(how) + + def mutate_leaf_parents( root: Node, grammar: Grammar, From af7aa5c6f2ee977a7d0ec098b1262e9cec206510 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 18:29:09 +0100 Subject: [PATCH 34/50] Test selection --- graph.py | 62 +++++++++++--------- perf.py | 6 +- test_graph.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 193 insertions(+), 32 deletions(-) diff --git a/graph.py b/graph.py index ce6db7a9f..b552f4bd4 100644 --- a/graph.py +++ b/graph.py @@ -109,9 +109,8 @@ class Leaf(NamedTuple): symbol: str op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) class Container(NamedTuple): @@ -119,18 +118,16 @@ class Container(NamedTuple): children: list[Node] op: Callable - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) class Passthrough(NamedTuple): symbol: str children: list[Node] - # Attach methods to nodes - dfs = dfs_node - bfs = bfs_node + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) Node: TypeAlias = Container | Passthrough | Leaf @@ -332,7 +329,7 @@ def select( ) -> Iterator[Node]: match how: case ("symbol", symbol): - for node in root.bfs(): + for node in bfs_node(root): if node.symbol == symbol: yield node case ("depth", depth): @@ -342,14 +339,17 @@ def select( queue_depth: list[tuple[Node, int]] = [(root, 0)] while queue_depth: nxt, d = queue_depth.pop(0) + if d in depth: + yield 
nxt + + if d >= depth.stop: + continue + match nxt: case Leaf(): - continue + pass case Passthrough(children=children) | Container(children=children): - if d in depth: - yield nxt - if d < depth.stop: - queue_depth.extend([(child, d + 1) for child in children]) + queue_depth.extend([(child, d + 1) for child in children]) case _: assert_never(nxt) @@ -368,7 +368,7 @@ def select( # the climb iteration leafs: dict[int, Node] = {} - queue_climb: list[Node] = [] + queue_climb: list[Node] = [root] while queue_climb: nxt = queue_climb.pop(0) this_id = id(nxt) @@ -385,17 +385,27 @@ def select( # Now we work backwards from the leafs for each of the possible parents # for the node id, yielding if we're within the climb path. If we've gone # pass the climb value, we can stop iterating there. - climb_stack: list[tuple[Node, int]] = [] - climb_stack.extend([(leaf, 0) for leaf in leafs.values()]) - while climb_stack: - node, climb_value = climb_stack.pop(-1) + climb_queue: list[tuple[Node, int]] = [] + climb_queue.extend([(leaf, 0) for leaf in leafs.values()]) + seen: set[int] = set() + while climb_queue: + node, climb_value = climb_queue.pop(0) + node_id = id(node) + if node_id in seen: + continue + if climb_value in climb: + seen.add(node_id) yield node if climb_value < climb.stop: possible_node_parents = parents[id(node)] - climb_stack.extend( - [(p, climb_value + 1) for p in possible_node_parents] + climb_queue.extend( + [ + (p, climb_value + 1) + for p in possible_node_parents + if id(p) not in seen + ] ) case _: @@ -911,12 +921,10 @@ def _build(_n: Node) -> Iterator[Any]: assert_never(node) -structure = { +grammar = { "S": ( - Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O"], - nn.Sequential, - ) + ["C", "reluconvbn", "S", "S C", "O O O"], + nn.Sequential, ), "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), "O": ["4", "1", "id"], diff --git a/perf.py b/perf.py index 268248b7d..28f7716e2 100644 --- a/perf.py +++ b/perf.py @@ -10,6 +10,7 @@ ReLUConvBN, parse, sample_grammar, + to_model, to_nxgraph, to_string, ) @@ -36,7 +37,6 @@ ), } - if __name__ == "__main__": import time @@ -44,7 +44,7 @@ rng = np.random.default_rng() sample: Node = sample_grammar("S", grammar=grammar, rng=rng) graph = to_nxgraph(sample) - # model = to_model(sample) + model = to_model(sample) t0 = time.perf_counter() samples = 10_000 @@ -52,7 +52,7 @@ for _ in range(samples): sample: Node = sample_grammar("S", grammar=grammar, rng=rng) string = to_string(sample) - parse(string=string, grammar=grammar) + node = parse(string=string, grammar=grammar) # graph = to_nxgraph(sample) # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) # model = to_model(sample) diff --git a/test_graph.py b/test_graph.py index 4de247c60..09759a51b 100644 --- a/test_graph.py +++ b/test_graph.py @@ -10,7 +10,10 @@ Node, ParseError, Passthrough, + bfs_node, + dfs_node, parse, + select, to_model, to_node_from_graph, to_nxgraph, @@ -184,7 +187,7 @@ def test_dfs_node_container() -> None: ], op=join, ) - outcome = list(node.dfs()) + outcome = list(dfs_node(node)) expected = [ # First Container( @@ -241,7 +244,7 @@ def test_bfs_node_container() -> None: ], op=join, ) - outcome = list(node.bfs()) + outcome = list(bfs_node(node)) expected = [ # First Container( @@ -279,3 +282,153 @@ def test_bfs_node_container() -> None: ] for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): assert e == o, f"Failed at index {i}" + + +def test_select_symbol() -> None: + root = Container( + "a", + children=[ + Container( + "b", + 
children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("symbol", "d"))) + assert selected == [ + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ] + + +def test_select_depth() -> None: + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("depth", 1))) + assert selected == root.children + + selected = list(select(root, how=("depth", range(1, 3)))) + expected = [ + # Depth 1 + *root.children, + # Depth 2 + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + ] + assert selected == expected + + +def test_select_climb() -> None: + # NOTE: The order is rather arbitrary and not much thought has been given to it. + # However the test still tests a particular order that was done by trial and + # error. Feel free to redo the order if this changes. + root = Container( + "a", + children=[ + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Leaf("l3", op=T("l3")), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + ], + op=join, + ) + selected = list(select(root, how=("climb", 0))) + assert selected == [ + Leaf("l3", op=T("l3")), + Leaf("l2", op=T("l2")), + Leaf("l4", op=T("l4")), + Leaf("l1", op=T("l1")), + ] + + selected = list(select(root, how=("climb", range(1, 3)))) + expected = [ + root, + Container("c", children=[Leaf("l2", op=T("l2"))], op=join), + Container( + "d", + children=[Leaf("l4", op=T("l4"))], + op=join, + ), + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + Container( + "b", + children=[ + Container( + "d", + children=[Leaf("l1", op=T("l1"))], + op=join, + ), + ], + op=join, + ), + ] + for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): + assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" From 38757a44075a1b8dda154f2f6f8bc7fccb46e60f Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 19:42:36 +0100 Subject: [PATCH 35/50] Rework mutations --- graph.py | 130 ++++++++++++++++++++------------------------ graph_playground.py | 47 ++++++++++++++++ 2 files changed, 107 insertions(+), 70 deletions(-) create mode 100644 graph_playground.py diff --git a/graph.py b/graph.py index b552f4bd4..be1346905 100644 --- a/graph.py +++ b/graph.py @@ -2,7 +2,7 @@ import itertools from collections import defaultdict -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field from functools import partial from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias @@ -412,87 +412,77 @@ def select( assert_never(how) -def mutate_leaf_parents( +def mutations( root: Node, grammar: Grammar, *, - rng: np.random.Generator, + which: 
Iterable[Node], + max_mutation_depth: int, + rng_shuffle: np.random.Generator | None = None, variables: dict[str, Node] | None = None, -) -> Node: - """Mutate a node, returning a different possibility for it.""" +) -> Iterator[Node]: + """Mutate nodes, returning all the different possibilities for them. + + Args: + root: The root from which to operate. + grammar: The grammar which holds the rules used for mutation. + which: What nodes to mutate, look at `select()`. + max_mutation_depth: The maximum depth allowed for bfs iteration + on the mutant nodes. + rng_shuffle: Whether to shuffle the return order. This takes place at the place + when considering the possibilities for a given node, and does not follow + the order of `NonTerminal.choices`. + variables: Any predefined values you'd like for different symbols. + + Returns: + A new tree per possible mutation + """ if isinstance(root, Leaf): raise ValueError(f"Can't mutate `Leaf`: {root}") - variables = variables or {} - parents: dict[int, Node] = {} - leaf_parents: list[Node] = [] + variables = variables or {} + mutation_ids = {id(n) for n in which} - def _fill(n: Node, *, parent: Node) -> None: - node_id = id(n) - parents[node_id] = parent - match n: - case Leaf(): - leaf_parents.append(parent) - case Passthrough(_, children) | Container(_, children): - for child in children: - _fill(child, parent=parent) - case _: - assert_never(n) - - for child in root.children: - _fill(child, parent=root) - - # Note, we can have duplicates here, that's fine, we want to weight those - # parents with many leafs more heavily... TODO: Maybe? - chosen_node: Node = rng.choice(leaf_parents) # type: ignore - chosen_node_id = id(chosen_node) - - match chosen_node: - case Passthrough() | Container(): - new_subnode = sample_grammar( - chosen_node.symbol, - grammar, - rng=rng, - # NOTE: subfunction will update variables dict - # with any instantiated `variables` if it doesn't - # exist already in the passed in `variables` - variables=variables, - ) - case Leaf(): - raise ValueError("don't pass leafs") - case _: - assert_never(chosen_node) - - def _build(n: Node): - # If we find the node to replace, replace it. - if id(n) == chosen_node_id: - return new_subnode - - # It may be the case that `sample_grammar` above populated - # `variables`, replacing one of the shared nodes with something - # new. In that case, we want to use the new sampled value wherever - # we encounter that symbol. 
- shared_node = variables.get(n.symbol) - if shared_node is not None: - return shared_node - - # Otherwise, we just rebuild as needed - match n: + def _inner(node: Node) -> Iterator[Node]: + match node: case Leaf(): - return n - case Container(symbol, children, op): - return Container(symbol, children=[_build(c) for c in children], op=op) - case Passthrough(symbol, children): - return Passthrough(symbol, children=[_build(c) for c in children]) - case _: - assert_never(n) + # We can't mutate leafs as they don't have possible choices to choose from + # by definition so we ignore it even if it's in the set of `mutation_ids` + yield node + case Passthrough(children=children) | Container(children=children): + rule = grammar.rules.get(node.symbol) + if not isinstance(rule, Grammar.NonTerminal): + raise ValueError( + "Expected a `NonTerminal` for symbol '{node.symbol}' from the" + f" grammar but got rule {rule}" + ) - return _build(root) + # If we've already determined the value of this shared symbol + if (existing := variables.get(node.symbol)) is not None: + yield existing + return + # If mutate, we return all possible bfs values from that node. + if id(node) in mutation_ids: + yield from bfs_grammar( + grammar, + node.symbol, + rng_shuffle=rng_shuffle, + max_depth=max_mutation_depth, + variables=variables, + ) + else: + children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] + for new_children in itertools.product(*children_itrs): + node = node._replace(children=list) + new_node = node._replace(children=new_children) + if rule.shared: + variables[new_node.symbol] = node + yield new_node + case _: + assert_never(node) -def mutate_many( - node: Node, grammar: Grammar, *, rng: np.random.Generator -) -> Iterator[Node]: ... + yield from _inner(root) def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: diff --git a/graph_playground.py b/graph_playground.py new file mode 100644 index 000000000..5b995feef --- /dev/null +++ b/graph_playground.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from graph import Grammar, mutations, parse, select, to_string + + +# Leafs +@dataclass +class T: + s: str + + # This is the `op()` + def __call__(self) -> str: + return self.s + + +def join(*s: str) -> str: + return "[" + "".join(s) + "]" + + +grammar_1 = Grammar.from_dict( + { + "s": (["a", "b", "p a", "p p"], join), + "p": ["a b", "s"], + "a": T("a"), + "b": T("b"), + } +) + +root = parse(grammar_1, "s(p(s(a), a))") + +selections = list(select(root, how=("climb", range(1, 3)))) +mutants = mutations( + root=root, + grammar=grammar_1, + which=selections, + max_mutation_depth=3, +) +mutants = list(mutants) + +import rich + +rich.print("grammar", grammar_1) +rich.print("root", f"{to_string(root)}") +rich.print("selections", [to_string(s) for s in selections]) +rich.print("mutants", [to_string(m) for m in mutants]) From 39fa2af09ff6e5ae3886b9a749b0357844771946 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:44:19 +0100 Subject: [PATCH 36/50] Fix parsing --- graph.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/graph.py b/graph.py index be1346905..5cf69b093 100644 --- a/graph.py +++ b/graph.py @@ -621,29 +621,31 @@ def parse(grammar: Grammar, string: str) -> Node: bracket_pairs: dict[int, int] = {} for i, tok in enumerate(tokens): match tok: - case ( - "," | (_, Grammar.Terminal()) | (_, Grammar.NonTerminal(shared=False)) - ): + case "," | (_, Leaf()): continue case 
")": start = bracket_stack.pop(-1) bracket_pairs[start] = i case "(": bracket_stack.append(i) - case (symbol, Grammar.NonTerminal(shared=True)): + case (symbol, Grammar.NonTerminal(shared=shared)): if i + 1 >= len(tokens): raise ParseError( - f"Symbol '{tok}' is 'shared', implying that it should" + f"Symbol '{tok}' is a `NonTerminal`, implying that it should" " contain some inner elements. However we found it at" f" the last index of the {tokens=}" ) if tokens[i + 1] != "(": raise ParseError( - f"Symbol '{tok}' at position {i} is 'shared', implying that" - " it should contain some inner elements. However it was not" - f" followed by a '(' at position {i + 1} in {tokens=}" + f"Symbol '{tok}' at position {i} is a `NonTerminal`," + " implying that it should contain some inner elements." + f" However it was not followed by a '(' at position {i + 1}" + f" in {tokens=}" ) - _shared_locs[symbol].append(i) + if shared is True: + _shared_locs[symbol].append(i) + case _: + assert_never(tok) # If we have more than one occurence of a shared symbol, # we validate their subtokens match From 5301cd4f57bd8c4e90c50ba2f31d2a8523c9a09e Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 10 Feb 2025 22:46:04 +0100 Subject: [PATCH 37/50] Fix mutation --- graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graph.py b/graph.py index 5cf69b093..8053dfb79 100644 --- a/graph.py +++ b/graph.py @@ -474,10 +474,9 @@ def _inner(node: Node) -> Iterator[Node]: else: children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] for new_children in itertools.product(*children_itrs): - node = node._replace(children=list) new_node = node._replace(children=new_children) if rule.shared: - variables[new_node.symbol] = node + variables[new_node.symbol] = new_node yield new_node case _: assert_never(node) From 537035b12f23e7654c4557b9d7e70298d55239f5 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 15:23:35 +0100 Subject: [PATCH 38/50] fix: stop unpacking of nn.Sequential --- graph.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/graph.py b/graph.py index 8053dfb79..aac96ea6c 100644 --- a/graph.py +++ b/graph.py @@ -8,7 +8,6 @@ from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never -import more_itertools import networkx as nx import numpy as np from torch import nn @@ -47,22 +46,21 @@ def next(self, n: int) -> int: class ReLUConvBN(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + def __init__(self, out_channels, kernel_size, stride, padding): super().__init__() self.kernel_size = kernel_size self.op = nn.Sequential( nn.ReLU(inplace=False), - nn.Conv3d( - in_channels, - out_channels, - kernel_size, + nn.LazyConv2d( + out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, dilation=2, bias=False, ), - nn.BatchNorm3d(out_channels, affine=True, track_running_stats=True), + nn.LazyBatchNorm2d(affine=True, track_running_stats=True), ) def forward(self, x): @@ -73,8 +71,8 @@ class Identity(nn.Module): def __init__(self): super().__init__() - def forward(self): - return self + def forward(self, x): + return x def dfs_node(node: Node) -> Iterator[Node]: @@ -881,30 +879,33 @@ def bfs_grammar( # noqa: C901, D103 def to_model(node: Node) -> Any: """Convert a parse tree node and its children into some object it represents.""" - def _build(_n: Node) -> Iterator[Any]: + def _build(_n: Node) -> list[Any] 
| Any: match _n: case Leaf(_, op): - yield op() + return op() case Container(_, children, op): # The problem is that each child could be either: # * A single 'thing', in the case of Leaf or Container # * Multiple things, in case it's a passthrough # Hence we flatten them out into a single big children itr - flat_children = more_itertools.collapse( - _build(child) for child in children - ) - yield op(*flat_children) + _l = [] + for child in children: + _b = _build(child) + if isinstance(_b, list): + _l.extend(_b) + continue + _l.append(_b) + + return op(*_l) case Passthrough(_, children): - yield from (_build(child) for child in children) + return [_build(child) for child in children] case _: assert_never(node) match node: case Leaf() | Container(): - itr = _build(node) - obj = next(itr, None) - assert obj is not None, "Should have recieved at least one object" - assert next(itr, None) is None, "Should not have recieved two objects" + obj = _build(node) + assert not isinstance(obj, list) return obj case Passthrough(symbol): raise ValueError(f"Can not call build on a `Passthrough` {symbol}") From d152452aacff0453c405a8e31a4cb44c1d12e518 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 16:03:39 +0100 Subject: [PATCH 39/50] tests: add mlp end-to-end test --- test_graph.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test_graph.py b/test_graph.py index 09759a51b..df10575d0 100644 --- a/test_graph.py +++ b/test_graph.py @@ -1,8 +1,12 @@ from __future__ import annotations +import time from dataclasses import dataclass +from functools import partial +import numpy as np import pytest +import torch from graph import ( Container, Grammar, @@ -13,12 +17,14 @@ bfs_node, dfs_node, parse, + sample_grammar, select, to_model, to_node_from_graph, to_nxgraph, to_string, ) +from torch import nn # Leafs @@ -54,6 +60,22 @@ def join(*s: str) -> str: } ) +grammar_3 = Grammar.from_dict( + { + "S": (["mlp", "O"], nn.Sequential), + "mlp": (["L", "O", "S O"], nn.Sequential), + "L": ( + ["linear64 linear128 relu O linear64 relu O", "linear64 elu linear64"], + nn.Sequential, + ), + "O": (["linear64", "linear64 relu", "linear128 elu"], nn.Sequential), + "linear64": partial(nn.LazyLinear, out_features=64), + "linear128": partial(nn.LazyLinear, out_features=64), + "relu": nn.ReLU, + "elu": nn.ELU, + } +) + @pytest.mark.parametrize( ("grammar", "string", "built", "node"), @@ -432,3 +454,20 @@ def test_select_climb() -> None: ] for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" + + +@pytest.mark.parametrize("grammar", [grammar_3]) +def test_sample_grammar_and_build_model(grammar: Grammar): + rng = np.random.default_rng(seed=42) + + x = torch.randn(32, 100) + + t0 = time.perf_counter() + samples = 1_000 + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + model: nn.Module = to_model(sample) + model(x) + assert sum(p.numel() for p in model.parameters()) > 0 + + assert time.perf_counter() - t0 < 1 From b75dac0083488995c5add0242825cd41fbf7e7dc Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 16:08:12 +0100 Subject: [PATCH 40/50] chore: move tests --- test_graph.py => tests/test_graph.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test_graph.py => tests/test_graph.py (100%) diff --git a/test_graph.py b/tests/test_graph.py similarity index 100% rename from test_graph.py rename to tests/test_graph.py 
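A minimal sketch (not part of the patch series) of the string-to-model round trip that the tests above exercise. It assumes the `graph.py` module from these patches is importable from the repo root and uses a toy grammar analogous to `grammar_3`; the symbol names here are illustrative only:

from functools import partial

import torch
from torch import nn

from graph import Grammar, parse, to_model

# Toy grammar: "S" wraps one or two "O" blocks, each an nn.Sequential.
toy_grammar = Grammar.from_dict(
    {
        "S": (["O", "O O"], nn.Sequential),
        "O": (["linear relu", "linear"], nn.Sequential),
        "linear": partial(nn.LazyLinear, out_features=64),
        "relu": nn.ReLU,
    }
)

# Parse a concrete string from the grammar and build the module it represents.
node = parse(toy_grammar, "S(O(linear, relu), O(linear))")
model = to_model(node)  # Sequential(Sequential(LazyLinear, ReLU), Sequential(LazyLinear))
out = model(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 64])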
From fee035a158bf5d66e9db4232c34c20fb4dbab8b6 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 16:27:47 +0100 Subject: [PATCH 41/50] tests: mutate test --- tests/test_graph.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index df10575d0..65e3dc479 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -3,6 +3,7 @@ import time from dataclasses import dataclass from functools import partial +from typing import Literal import numpy as np import pytest @@ -16,6 +17,7 @@ Passthrough, bfs_node, dfs_node, + mutations, parse, sample_grammar, select, @@ -471,3 +473,46 @@ def test_sample_grammar_and_build_model(grammar: Grammar): assert sum(p.numel() for p in model.parameters()) > 0 assert time.perf_counter() - t0 < 1 + + +@pytest.mark.parametrize( + ("grammar", "how"), + [ + (grammar_3, ("symbol", "S")), + (grammar_3, ("depth", 2)), + (grammar_3, ("depth", range(1, 3))), + (grammar_3, ("climb", 2)), + (grammar_3, ("climb", range(1, 3))), + ], +) +def test_sample_grammar_and_mutate( + grammar: Grammar, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +): + rng = np.random.default_rng(seed=42) + + x = torch.randn(32, 100) + + t0 = time.perf_counter() + samples = 1_000 + for _ in range(samples): + sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + muts = mutations( + root=sample, + grammar=grammar, + which=select(root=sample, how=how), + max_mutation_depth=3, + ) + + assert len(list(muts)) > 0 + + for _mut in muts: + model: nn.Module = to_model(_mut) + model(x) + assert sum(p.numel() for p in model.parameters()) > 0 + + assert time.perf_counter() - t0 < 1 From 0a76b9bf90bd2a5b37f60b0124cf40cfb3c2be0f Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 20 Feb 2025 17:09:32 +0100 Subject: [PATCH 42/50] feat(grammar): integrate grammar into search space --- neps/__init__.py | 3 +- neps/api.py | 4 +- neps/optimizers/algorithms.py | 24 +- neps/optimizers/bayesian_optimization.py | 2 +- neps/optimizers/bracket_optimizer.py | 2 +- neps/optimizers/ifbo.py | 2 +- neps/optimizers/random_search.py | 17 +- neps/space/__init__.py | 2 + neps/space/grammar.py | 1289 ++++++++++++++++++++++ neps/space/search_space.py | 38 +- tests/test_search_space.py | 18 +- tests/test_state/test_neps_state.py | 2 +- 12 files changed, 1367 insertions(+), 36 deletions(-) create mode 100644 neps/space/grammar.py diff --git a/neps/__init__.py b/neps/__init__.py index 756217609..408a33fc3 100644 --- a/neps/__init__.py +++ b/neps/__init__.py @@ -4,7 +4,7 @@ from neps.optimizers.optimizer import SampledConfig from neps.plot.plot import plot from neps.plot.tensorboard_eval import tblogger -from neps.space import Categorical, Constant, Float, Integer, SearchSpace +from neps.space import Categorical, Constant, Float, Grammar, Integer, SearchSpace from neps.state import BudgetInfo, Trial from neps.status.status import status from neps.utils.files import load_and_merge_yamls as load_yamls @@ -15,6 +15,7 @@ "Categorical", "Constant", "Float", + "Grammar", "Integer", "SampledConfig", "SearchSpace", diff --git a/neps/api.py b/neps/api.py index 77c8ebcf2..ae38ff59b 100644 --- a/neps/api.py +++ b/neps/api.py @@ -18,7 +18,7 @@ from ConfigSpace import ConfigurationSpace from neps.optimizers.algorithms import CustomOptimizer - from neps.space import Parameter, SearchSpace + from neps.space import Constant, Grammar, Parameter, SearchSpace 
from neps.state import EvaluatePipelineReturn logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def run( # noqa: PLR0913 evaluate_pipeline: Callable[..., EvaluatePipelineReturn] | str, pipeline_space: ( - Mapping[str, dict | str | int | float | Parameter] + Mapping[str, dict | str | int | float | Parameter | Constant | Grammar] | SearchSpace | ConfigurationSpace ), diff --git a/neps/optimizers/algorithms.py b/neps/optimizers/algorithms.py index 6de8f67be..dafad4f75 100644 --- a/neps/optimizers/algorithms.py +++ b/neps/optimizers/algorithms.py @@ -82,7 +82,7 @@ def _bo( f" Got: {pipeline_space.fidelities}" ) - parameters = pipeline_space.searchables + parameters = {**pipeline_space.numerical, **pipeline_space.categoricals} match initial_design_size: case "ndim": @@ -126,9 +126,6 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 sampler: Literal["uniform", "prior", "priorband"] | PriorBandSampler | Sampler, bayesian_optimization_kick_in_point: int | float | None, sample_prior_first: bool | Literal["highest_fidelity"], - # NOTE: This is the only argument to get a default, since it - # is not required for hyperband style algorithms, only single bracket - # style ones. early_stopping_rate: int | None, device: torch.device | None, ) -> BracketOptimizer: @@ -183,7 +180,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 """ assert pipeline_space.fidelity is not None fidelity_name, fidelity = pipeline_space.fidelity - parameters = pipeline_space.searchables + parameters = {**pipeline_space.numerical, **pipeline_space.categoricals} if len(pipeline_space.fidelities) != 1: raise ValueError( @@ -324,9 +321,8 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 def determine_optimizer_automatically(space: SearchSpace) -> str: - has_prior = any( - parameter.prior is not None for parameter in space.searchables.values() - ) + parameters = {**space.numerical, **space.categoricals} + has_prior = any(parameter.prior is not None for parameter in parameters.values()) has_fidelity = len(space.fidelities) > 0 match (has_prior, has_fidelity): @@ -360,14 +356,18 @@ def random_search( In this case, the max fidelity is always used. """ if ignore_fidelity: - parameters = pipeline_space.searchables + parameters = {**pipeline_space.numerical, **pipeline_space.categoricals} else: - parameters = {**pipeline_space.searchables, **pipeline_space.fidelities} + parameters = { + **pipeline_space.numerical, + **pipeline_space.categoricals, + **pipeline_space.fidelities, + } return RandomSearch( space=pipeline_space, encoder=ConfigEncoder.from_parameters(parameters), - sampler=( + numerical_sampler=( Prior.from_parameters(parameters) if use_priors else Uniform(ndim=len(parameters)) @@ -445,7 +445,7 @@ def ifbo( space, fid_bins = _adjust_space_to_match_stepsize(pipeline_space, step_size) assert space.fidelity is not None fidelity_name, fidelity = space.fidelity - parameters = space.searchables + parameters = {**pipeline_space.numerical, **pipeline_space.categoricals} match initial_design_size: case "ndim": diff --git a/neps/optimizers/bayesian_optimization.py b/neps/optimizers/bayesian_optimization.py index ec556803d..b019136ac 100644 --- a/neps/optimizers/bayesian_optimization.py +++ b/neps/optimizers/bayesian_optimization.py @@ -86,7 +86,7 @@ def __call__( n: int | None = None, ) -> SampledConfig | list[SampledConfig]: assert self.space.fidelity is None, "Fidelity not supported yet." 
- parameters = self.space.searchables + parameters = {**self.space.numerical, **self.space.categoricals} n_to_sample = 1 if n is None else n n_sampled = len(trials) diff --git a/neps/optimizers/bracket_optimizer.py b/neps/optimizers/bracket_optimizer.py index d5317c0df..cea6624d9 100644 --- a/neps/optimizers/bracket_optimizer.py +++ b/neps/optimizers/bracket_optimizer.py @@ -257,7 +257,7 @@ def __call__( # noqa: C901, PLR0912 ) -> SampledConfig | list[SampledConfig]: assert n is None, "TODO" space = self.space - parameters = space.searchables + parameters = {**self.space.numerical, **self.space.categoricals} # If we have no trials, we either go with the prior or just a sampled config if len(trials) == 0: diff --git a/neps/optimizers/ifbo.py b/neps/optimizers/ifbo.py index 4e7d90726..58671bcfd 100755 --- a/neps/optimizers/ifbo.py +++ b/neps/optimizers/ifbo.py @@ -137,7 +137,7 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: assert self.space.fidelity is not None fidelity_name, fidelity = self.space.fidelity - parameters = self.space.searchables + parameters = {**self.space.numerical, **self.space.categoricals} assert n is None, "TODO" ids = [int(config_id.split("_", maxsplit=1)[0]) for config_id in trials] diff --git a/neps/optimizers/random_search.py b/neps/optimizers/random_search.py index 5b6742a6a..fab67e546 100644 --- a/neps/optimizers/random_search.py +++ b/neps/optimizers/random_search.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +import numpy as np + from neps.optimizers.optimizer import SampledConfig if TYPE_CHECKING: @@ -18,7 +20,7 @@ class RandomSearch: space: SearchSpace encoder: ConfigEncoder - sampler: Sampler + numerical_sampler: Sampler def __call__( self, @@ -28,12 +30,21 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: n_trials = len(trials) _n = 1 if n is None else n - configs = self.sampler.sample(_n, to=self.encoder.domains) + configs_tensor: Tensor = self.numerical_sampler.sample(_n, to=self.encoder) - config_dicts = self.encoder.decode(configs) + config_dicts = self.encoder.decode(configs_tensor) for config in config_dicts: config.update(self.space.constants) + # TODO: We should probably have a grammar sampler class, not do it manually here + # This works for now but should be updated. + if self.space.grammar is not None: + rng = np.random.default_rng() # TODO: We should be able to seed this. + grammar_key, grammar = self.space.grammar + for config in config_dicts: + sample = grammar.sample(rng=rng) + config.update({grammar_key: sample.to_string()}) + if n is None: config = config_dicts[0] config_id = str(n_trials + 1) diff --git a/neps/space/__init__.py b/neps/space/__init__.py index f2bbc55ca..e5d8889e3 100644 --- a/neps/space/__init__.py +++ b/neps/space/__init__.py @@ -1,5 +1,6 @@ from neps.space.domain import Domain from neps.space.encoding import ConfigEncoder +from neps.space.grammar import Grammar from neps.space.parameters import Categorical, Constant, Float, Integer, Parameter from neps.space.search_space import SearchSpace @@ -9,6 +10,7 @@ "Constant", "Domain", "Float", + "Grammar", "Integer", "Parameter", "SearchSpace", diff --git a/neps/space/grammar.py b/neps/space/grammar.py new file mode 100644 index 000000000..d3064847f --- /dev/null +++ b/neps/space/grammar.py @@ -0,0 +1,1289 @@ +"""A module containing the [`Grammar`][neps.space.grammar.Grammar] parameter. 
+ +A `Grammar` contains a list of production `rules`, which produce a _string_ from +the grammar, as well as some `start_symbol` which is used by optimizers. + +!!! note + + We make a distinction that **string** is not a python `str`, and represents + an expanded set of rules from the grammar. + +Each rule, either a [`Terminal`][neps.space.grammar.Grammar.Terminal] or +[`NonTerminal`][neps.space.grammar.Grammar.NonTerminal], is a key-value pair, +where the key is a symbol, such as `"S"` and the value is what the symbol represents. +See the example below. + +You can create a `Grammar` conveninetly using +[`Grammar.from_dict({...})`][neps.space.grammar.Grammar.from_dict]. + +!!! example + + ```python + from neps import Grammar + + # Using bare types + grammar = Grammar.from_dict({ + "S": (["OP OP OP", "OP OP"], nn.Sequential), # A seq with either 3 or 2 children + "OP": ["linear golu", "linear relu"], # A choice between linear with a golu/relu + "linear": partial(nn.LazyLinear, out_features=10, bias=False), # A linear layer + "relu": nn.ReLU, # A relu activation + "golu": nn.GoLU, # A golu activation + }) + + # Explicitly + grammar = Grammar({ + "S": NonTerminal(choices=["OP OP OP"], op=nn.Sequential, shared=False), + "OP": NonTerminal(choices=["linear golu", "linear relu"], op=None, shared=False), + "relu": Terminal(nn.ReLU), + "linear": Terminal(partial(nn.LazyLinear, out_features=10, bias=False)), + "golu": Terminal(nn.GoLU), + }) + ``` + +A _string_ from a `Grammar` can be produced in several ways: + +* [`grammar.parse()`][neps.space.grammar.Grammar.parse] - parse a grammar from a `str` + into a _string_, which is represented by a [`Node`][neps.space.grammar.Node] tree. + The inverse of this operation is to call `node.to_string()`. +* [`grammar.sample()`][neps.space.grammar.Grammar.sample] - Sample a random string from + the grammar. +* [`grammar.mutations()`][neps.space.grammar.Grammar.mutations] - This takes in a `Node`, + which represents a _string_ from the grammar, and can mutate selected points of the + string. You can use the function [`node.select()`][neps.space.grammar.select] for + different strategies to select parts of the string to mutate, for example, all + parents of a leaf with `node.select(how=("climb", 1))` or specific symbols using + `node.select(how=("symbol", "OP"))`. +* [`grammar.bfs()`][neps.space.grammar.Grammar.bfs] - This iterates through all possible + strings producable from the grammar, using a max-depth to prevent infinite recursion. + +As mentioned in the above methods, a string from the the `Grammar` is represnted as a tree +of [`Node`][neps.space.grammar.Node], which also contain the associated meaning of the +string parts, i.e. what operation that symbol should do. + +* [`Leaf`][neps.space.grammar.Leaf] - A symbol with no children and an operation. +* [`Container`][neps.space.grammar.Container] - A symbol with children and some containing + operation, for example an `nn.Sequential`. +* [`Passthrough`][neps.space.grammar.Passthrough] - A symbol with children but **no** + operation. It's children will be passed up to its parent until it hits a `Container`. + +Please see the associated docstrings for more information. + +For the most part, you can consider all of these as a [`Node`][neps.space.grammar.Node], +which has the following attached functions: + +* [`to_string()`][neps.space.grammar.to_string] - Convert to it's python `str` + representation. 
+* [`to_model()`][neps.space.grammar.to_model] - Convert it into some kind of model, + defined by its operations. Normally this represnts some `nn.Module` structure but it + is not necessarily torch specific. +* [`to_nxgraph()`][neps.space.grammar.to_nxgraph] - Convert it into a `nx.Digraph` which + can be useful for optimization or other applications such as plotting. The inverse + operation is called from the grammar, + [`grammar.node_from_nxgraph()`][neps.space.grammar.Grammar.node_from_nxgraph] +* [`select()`][neps.space.grammar.select] - Select certain nodes of the string by + a criterion. +* [`dfs()`][neps.space.grammar.dfs] - DFS iteration over the nodes of the string. +* [`bfs()`][neps.space.grammar.bfs] - BFS iteration over the nodes of the string. +""" + +from __future__ import annotations + +import itertools +from collections import defaultdict +from collections.abc import Callable, Iterable, Iterator +from dataclasses import dataclass, field +from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias +from typing_extensions import assert_never + +import networkx as nx +import numpy as np + +from neps.exceptions import NePSError + + +class ParseError(NePSError): + """An error occured while parsing a grammar string.""" + + +@dataclass +class _BufferedRandInts: + rng: np.random.Generator + buffer_size: int = 50 + _cur_ix: int = 0 + + MAX_INT: ClassVar[int] = np.iinfo(np.int64).max + _nums: list[int] = field(default_factory=list) + + def next(self, n: int) -> int: + if self._cur_ix >= len(self._nums): + self._nums = self.rng.integers( + self.MAX_INT, size=self.buffer_size, dtype=np.int64 + ).tolist() + + self._cur_ix = 0 + + i = self._nums[self._cur_ix] % n + + self._cur_ix += 1 + return i + + +def dfs_node(node: Node) -> Iterator[Node]: + """Perform a depth-first search iteration on the node.""" + stack: list[Node] = [node] + while stack: + nxt = stack.pop(-1) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + stack.extend(reversed(children)) + case _: + assert_never(nxt) + + +def bfs_node(node: Node) -> Iterator[Node]: + """Perform a breadth-first search iteration on the node.""" + queue: list[Node] = [node] + while queue: + nxt = queue.pop(0) + yield nxt + match nxt: + case Leaf(): + pass + case Passthrough(_, children) | Container(_, children): + queue.extend(children) + case _: + assert_never(nxt) + + +def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: # noqa: C901 + """Convert a node and it's children into an `nx.DiGraph`. + + Args: + root: The node to start from. + include_passthroughs: Whether to include passthrough symbols into the + produced graph. 
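To make the graph conversion concrete, here is a hedged sketch of a `to_nxgraph()` round trip using a toy grammar with placeholder ops; the grammar and the seed are illustrative, not part of the patch:

```python
import numpy as np

from neps.space.grammar import Grammar

# Toy grammar: all ops are placeholder callables.
grammar = Grammar.from_dict("S", {
    "S": (["OP OP", "OP"], lambda *cs: cs),
    "OP": (["a", "b"], lambda *cs: cs),
    "a": lambda: "a", "b": lambda: "b",
})

node = grammar.sample(rng=np.random.default_rng(0))
graph = node.to_nxgraph()                    # nx.DiGraph whose "label" attributes hold the symbols
recovered = grammar.node_from_graph(graph)   # inverse direction, defined further below
assert recovered.to_string() == node.to_string()
```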
+ """ + nodes: list[tuple[int, dict]] = [] + edges: list[tuple[int, int]] = [] + id_generator: Iterator[int] = itertools.count() + + def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: + node_id = next(id_generator) + match node: + # Atoms are just a node with an edge to its parent + case Leaf(symbol): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + # If we have a passthrough and shouldn't include them, we simply + # forward on the `parent_id` we recieved to the children + case Passthrough(_, children) if include_passthroughs is False: + for child in children: + _recurse_fill_lists(child, parent_id=parent_id) + + # Containers are a node in the graph, with edges to its + # children (direct, or through passthrough) + case Container(symbol, children, _) | Passthrough(symbol, children): + nodes.append((node_id, {"label": symbol})) + edges.append((parent_id, node_id)) + + for child in children: + _recurse_fill_lists(child, parent_id=node_id) + + case _: + assert_never(root.kind) + + graph = nx.DiGraph() + root_id = next(id_generator) + nodes.append((root_id, {"label": root.symbol})) + match root: + case Leaf(): + pass + case Passthrough(_, children) if include_passthroughs is False: + raise ValueError( + f"Can't create a graph starting from a `Passthrough` {root.symbol}, " + " unless `include_passthrough`" + ) + case Container(_, children, _) | Passthrough(_, children): + for child in children: + _recurse_fill_lists(child, parent_id=root_id) + case _: + assert_never(root) + + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def to_model(node: Node) -> Any: + """Convert a parse tree node and its children into some object it represents.""" + + def _build(_n: Node) -> list[Any] | Any: + match _n: + case Leaf(_, op): + return op() + case Container(_, children, op): + # The problem is that each child could be either: + # * A single 'thing', in the case of Leaf or Container + # * Multiple things, in case it's a passthrough + # Hence we flatten them out into a single big children itr + _l = [] + for child in children: + _b = _build(child) + if isinstance(_b, list): + _l.extend(_b) + continue + _l.append(_b) + + return op(*_l) + case Passthrough(_, children): + return [_build(child) for child in children] + case _: + assert_never(node) + + match node: + case Leaf() | Container(): + obj = _build(node) + assert not isinstance(obj, list) + return obj + case Passthrough(symbol): + raise ValueError(f"Can not call build on a `Passthrough` {symbol}") + case _: + assert_never(node) + + +def select( # noqa: C901, PLR0912, PLR0915 + root: Node, + *, + how: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ), +) -> Iterator[Node]: + """Iterate through the tree and select nodes according to `how=`. + + Args: + root: the root node to start from. + how: which nodes to select. In the case of `"depth"` and `"climb"`, you can either + provide a specific value `int`, or else a `range`, where anything that has + a value in that `range` is included. Note that this follows the same + convention that `4 in range(3, 5)` but `5 not in range(3, 5)`, + i.e. that the stop boundary is non-inclusive. + + * `"symbol"` - Select all nodes which have the given symbol. + * `"depth"`- Select all nodes which are at a given depth, either a particular + depth value or a range of depth values. 
The `root` is defined to be at + `depth == 0` while its direct children are defined to be at `depth == 1`. + * `"climb"`- Select all nodes which are at a given distance away from a leaf. + Leafs are defined to be at `climb == 0`, while any direct parents + of a leaf are `climb == 1`. + """ + match how: + case ("symbol", symbol): + for node in bfs_node(root): + if node.symbol == symbol: + yield node + case ("depth", depth): + if isinstance(depth, int): + depth = range(depth, depth + 1) + + queue_depth: list[tuple[Node, int]] = [(root, 0)] + while queue_depth: + nxt, d = queue_depth.pop(0) + if d in depth: + yield nxt + + if d >= depth.stop: + continue + + match nxt: + case Leaf(): + pass + case Passthrough(children=children) | Container(children=children): + queue_depth.extend([(child, d + 1) for child in children]) + case _: + assert_never(nxt) + + case ("climb", climb): + if isinstance(climb, int): + climb = range(climb, climb + 1) + + # First, we iterate downwards, populating parent paths back + # up. As the id for a Leaf is shared across all similar leafs + # as well as the fact shared nodes will share the same node id, + # we could have multiple parents per child id. + parents: defaultdict[int, list[Node]] = defaultdict(list) + + # We remove duplicates using a dict and the shared ids, a list would + # end up with duplicates for every leaf. We use this later to begin + # the climb iteration + leafs: dict[int, Node] = {} + + queue_climb: list[Node] = [root] + while queue_climb: + nxt = queue_climb.pop(0) + this_id = id(nxt) + match nxt: + case Leaf(): + leafs[this_id] = nxt + case Passthrough(children=children) | Container(children=children): + for child in children: + parents[id(child)].append(nxt) + queue_climb.extend(children) + case _: + assert_never(nxt) + + # Now we work backwards from the leafs for each of the possible parents + # for the node id, yielding if we're within the climb path. If we've gone + # pass the climb value, we can stop iterating there. + climb_queue: list[tuple[Node, int]] = [] + climb_queue.extend([(leaf, 0) for leaf in leafs.values()]) + seen: set[int] = set() + while climb_queue: + node, climb_value = climb_queue.pop(0) + node_id = id(node) + if node_id in seen: + continue + + if climb_value in climb: + seen.add(node_id) + yield node + + if climb_value < climb.stop: + possible_node_parents = parents[id(node)] + climb_queue.extend( + [ + (p, climb_value + 1) + for p in possible_node_parents + if id(p) not in seen + ] + ) + + case _: + assert_never(how) + + +# TODO: Optimization, we don't need to recompute shared substrings. +# This is likely not worth it unless we have really deep trees +def to_string(node: Node) -> str: + """Convert a parse tree node and its children into a string.""" + match node: + case Leaf(symbol): + return symbol + case Passthrough(symbol, children) | Container(symbol, children): + return f"{symbol}({', '.join(to_string(c) for c in children)})" + case _: + assert_never(node) + return None + + +class Leaf(NamedTuple): + """A node which has no children. + + !!! note + + As we only ever have one kind of leaf per symbol, we tend to re-use the + same instance of a `Leaf` which gets re-used where it needs to. In contrast, + a `Container` and `Passthrough` may have different children per symbol and a new + instance is made each time. + + Args: + symbol: The string symbol associated with this `Leaf`. + op: The associated operations with this `symbol`. 
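A short sketch of the selection strategies above, again using a toy grammar with placeholder ops (the parsed string and grammar are purely illustrative):

```python
from neps.space.grammar import Grammar

# Toy grammar: all ops are placeholder callables.
grammar = Grammar.from_dict("S", {
    "S": (["OP OP", "OP"], lambda *cs: cs),
    "OP": (["a", "b"], lambda *cs: cs),
    "a": lambda: "a", "b": lambda: "b",
})

root = grammar.parse("S(OP(a), OP(b))")

ops = list(root.select(how=("symbol", "OP")))        # both "OP" nodes
children = list(root.select(how=("depth", 1)))       # direct children of the root
leaf_parents = list(root.select(how=("climb", 1)))   # every parent of a leaf
```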
+ """ + + symbol: str + op: Callable + + def __hash__(self) -> int: + return hash(self.symbol) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +class Container(NamedTuple): + """A node which contains children and has an associated operation. + + Args: + symobl: The string symbol associated with this `Container`. + children: The direct children of this node. When instantiating this container, + it will be called with it's instantiated children with `op(*children)`. + op: The associated operation with this node, such as an `nn.Sequential`. + """ + + symbol: str + children: list[Node] + op: Callable + + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +class Passthrough(NamedTuple): + """A node which contains children but has no associated operation. + + This is used for things such as `"OP": ["conv2d", "conv3d", "identity"]`, where + `"OP"` does not have some kind of container operation and is used to make a choice + between various symbols. + + Args: + symbol: The associated symbol with this `Passthrough`. + children: The direct children of this node. As this node can not be instantiated, + the children of this `Passthrough` are forward on to this nodes parents. + """ + + symbol: str + children: list[Node] + + def __hash__(self) -> int: + return hash(self.symbol) + hash(tuple(self.children)) + + dfs = dfs_node + bfs = bfs_node + to_string = to_string + to_nxgraph = to_nxgraph + to_model = to_model + select = select + + +Node: TypeAlias = Container | Passthrough | Leaf +"""The possible nodes in a constructed instance of a string from the grammar. + +Please see the associated types for their description or the docstring of a +[`Grammar`][neps.space.grammar.Grammar]. +""" + + +@dataclass +class Grammar: + """A grammar defines a search space of symbols which may contain other symbols. + + !!! tip + + You most likely want to create one of these using + [`from_dict()`][neps.space.grammar.Grammar.from_dict]. + + A grammar consists of `rules: dict[str, Grammar.Terminal | Grammar.NonTerminal]` + where the key is a string symbol, and the values are what that string symbol + represents. The initial symbol used by optimizers is specified using `start_symbol`. + + The [`Grammar.Terminal`][neps.space.Grammar.Terminal] represents some kind of leaf + node of a computation graph, such as a function call or some operation which + does not have any children dependancies, for example an `nn.Linear`. This is + modeled as a [`Node`][neps.space.grammar.Node], specifically the + [`Leaf`][neps.space.grammar.Leaf] type. + + The [`Grammar.NonTerminal`][neps.space.Grammar.NonTerminal] represents some kind of + intermediate operation, which contains sub-symbols which are sub-computations of + a computation graph. A common example of this is when `op=nn.Sequential`, which by + itself does not really do any computations but relies on the computation of it's + children which it performs one after another. If there is an associated `op=`, then + we consider this be a [`Container`][neps.space.grammar.Container] kind of + [`Node`][neps.space.grammar.Node]. If there is **no** associated `op=`, then we + consider this to be a [`Passthrough`][neps.space.grammar.Passthrough] kind of + [`Node`][neps.space.grammar.Node]. 
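To make the `Container`/`Passthrough` distinction concrete, here is a hedged sketch of building an actual model from a string. It assumes `torch` is available purely for illustration; the layer choices and sizes are not part of the patch:

```python
from functools import partial

from torch import nn

from neps.space.grammar import Grammar

net_grammar = Grammar.from_dict("S", {
    "S": (["OP OP", "OP"], nn.Sequential),             # Container: op wraps its children
    "OP": ["linear relu", "linear"],                    # Passthrough: a choice only, no op
    "linear": partial(nn.LazyLinear, out_features=8),   # Terminal -> Leaf
    "relu": nn.ReLU,                                     # Terminal -> Leaf
})

model = net_grammar.to_model("S(OP(linear, relu), OP(linear))")
# -> nn.Sequential(LazyLinear, ReLU, LazyLinear): the Passthrough "OP" nodes are
#    flattened into the enclosing Container when the model is built.
```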
+ + For a `Grammar.NonTerminal`, you may also specify if it is `shared: bool`, which is + by default `False`. When explicitly set as `True`, all choices made for its children + will be shared through the generated/sampled/parsed string. In constrast, if + `shared=False`, then any specific instance of this symbol may have different children. + + Args: + start_symbol: The starting symbol used by optimizers. + rules: The possible grammar rules which define the structure of the grammar. + """ + + start_symbol: str + rules: dict[str, Terminal | NonTerminal] + _shared: dict[str, NonTerminal] = field(init=False) + _leafs: dict[str, Leaf] = field(init=False) + + class Terminal(NamedTuple): + """A symbol which has no children and an associated operation. + + When a specific instance of a string from this grammar is made, this + rule will create a [`Leaf`][neps.space.grammar.Leaf]. + + Args: + op: The associated operation. + """ + + op: Callable + + class NonTerminal(NamedTuple): + """A symbol which has different possible children. + + Depending on whether `op=` is specified or not, this will either be a + [`Container`][neps.space.grammar.Container] or a + [`Passthrough`][neps.space.grammar.Passthrough]. + + Args: + choices: The list of possible children to place inside this `NonTerminal`. + Different possibilities are specified by the elements of the list. + When a `str` contains multiple symbols that are space seperated, these + will both be children. + + ``` + # The following says that we have a choice between "a", "b" and "c d". + # In the case that "c d" is chosen, both of those will be children of the + # created node. + ["a", "b", "c d"] + ``` + + op: The associated operation with this node, if any. + shared: Whether the choices made for this symbol should be shared throughout + the tree, or whether they should be considred independant. + """ + + choices: list[str] + op: Callable | None = None + shared: bool = False + + def __post_init__(self) -> None: + start_rule = self.rules.get(self.start_symbol, None) + if start_rule is None: + raise ValueError( + f"The start_symbol '{self.start_symbol}' should be one of the symbols" + f" in rules, which are {self.rules.keys()}" + ) + self._shared = { + s: r + for s, r in self.rules.items() + if isinstance(r, Grammar.NonTerminal) and r.shared + } + self._leafs = { + s: Leaf(s, r.op) + for s, r in self.rules.items() + if isinstance(r, Grammar.Terminal) + } + + @classmethod + def from_dict( + cls, + start_symbol: str, + grammar: dict[ + str, + Callable + | list[str] + | tuple[list[str], Callable] + | Grammar.Terminal + | Grammar.NonTerminal, + ], + ) -> Grammar: + """Create a `Grammar` from a dictionary. + + Please see the module doc for more. + + Args: + start_symbol: The starting symbol from which to produce strings. + grammar: The rules of the grammar. + """ + rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} + for symbol, rule in grammar.items(): + match rule: + case Grammar.Terminal() | Grammar.NonTerminal(): + rules[symbol] = rule + case (choices, op) if isinstance(choices, list) and callable(op): + # > e.g. "S": (["A", "A B", "C"], op) + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, op, shared=False) + + case choices if isinstance(choices, list): + # > e.g. 
"S": ["A", "A B", "C"] + rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) + missing = rhs - grammar.keys() + if any(missing): + raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") + + rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False) + + case op if callable(op): + # > e.g. "S": op + rules[symbol] = Grammar.Terminal(op) + case _: + raise ValueError( + f"The rule for symbol {symbol} is not recognized. Should be" + " a list of of symbols, a callable or a tuple with both." + f"\n Got {rule}" + ) + + return Grammar(start_symbol=start_symbol, rules=rules) + + def sample( # noqa: C901, PLR0912 + self, + symbol: str | None = None, + *, + rng: np.random.Generator | _BufferedRandInts, + variables: dict[str, Node] | None = None, + ) -> Node: + """Sample a random string from this grammar. + + Args: + symbol: The symbol to start from. If not provided, this will use + the `start_symbol`. + rng: The random generator by which sampling is done. + variables: Any shared variables to use in the case that a sampled + rule has `shared=True`. + + Returns: + The root of the sampled string. + """ + if isinstance(rng, np.random.Generator): + rng = _BufferedRandInts(rng=rng) + + if symbol is None: + symbol = self.start_symbol + + variables = variables or {} + rule = self.rules.get(symbol) + if rule is None: + raise KeyError(f"'{symbol}' not in grammar keys {self.rules.keys()}") + + stack: list[Container | Passthrough] = [] + match rule: + case Grammar.Terminal(): + return self._leafs[symbol] + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(symbol) + if shared_node is not None: + return shared_node + + i = rng.next(len(rule.choices)) + initial_sample = rule.choices[i] + children_symbols = initial_sample.split(" ") + root = ( + Passthrough(symbol, []) if op is None else Container(symbol, [], op) + ) + stack.append(root) + case _: + assert_never(rule) + + while stack: + parent = stack.pop() + i = rng.next(len(choices)) + choice = choices[i] + children_symbols = choice.split(" ") + + for child_symbol in children_symbols: + rule = self.rules[child_symbol] + match rule: + case Grammar.Terminal(): + parent.children.append(self._leafs[child_symbol]) + case Grammar.NonTerminal(choices, op, shared): + shared_node = variables.get(child_symbol) + if shared_node is not None: + parent.children.append(shared_node) + continue + + sub_parent = ( + Passthrough(child_symbol, []) + if op is None + else Container(child_symbol, [], op) + ) + parent.children.append(sub_parent) + stack.append(sub_parent) + + if shared: + variables[child_symbol] = sub_parent + case _: + assert_never(rule) + + return root + + def node_from_graph(self, graph: nx.DiGraph) -> Node: + """Convert an `nx.DiGraph` into a string. + + Args: + graph: The graph, produced by + [`to_nxgraph()`][neps.space.grammar.Grammar.to_nxgraph] + + Returns: + The root of the string produced from the graph. + """ + _root = next((n for n, d in graph.in_degree if d == 0), None) + if _root is None: + raise ValueError( + "Could not find a root in the given graph (a node with indegree 1)." 
+ ) + + variables: dict[str, Node] = {} + + def _recurse(node_id: int) -> Node: + symbol = graph.nodes[node_id].get("label") + if symbol is None: + raise ValueError(f"Node {node_id} does not have a 'label' property.") + + rule = self.rules.get(symbol) + if rule is None: + raise ValueError( + f"Symbol '{symbol}' not found in grammar rules: {self.rules.keys()}" + ) + + # Based on the type of rule, construct the proper node + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + case Grammar.NonTerminal(op=op): + if (shared_node := variables.get(symbol)) is not None: + return shared_node + + children = [ + _recurse(child_id) for child_id in graph.successors(node_id) + ] + node = ( + Passthrough(symbol, children) + if op is None + else Container(symbol, children, op) + ) + if rule.shared: + variables[symbol] = node + case _: + raise ValueError( + f"Unexpected rule type for symbol '{symbol}': {rule}" + ) + + return node + + # Start with the root node + return _recurse(_root) + + def mutations( + self, + root: Node, + *, + which: Iterable[Node], + max_mutation_depth: int, + rng_shuffle: np.random.Generator | None = None, + variables: dict[str, Node] | None = None, + ) -> Iterator[Node]: + """Mutate nodes, returning all the different possibilities for them. + + Args: + root: The root from which to operate. + which: What nodes to mutate, look at `select()`. + max_mutation_depth: The maximum depth allowed for bfs iteration + on the mutant nodes. + rng_shuffle: Whether to shuffle the return order. This takes place at the + place when considering the possibilities for a given node, and does + not follow the order of `NonTerminal.choices`. + variables: Any predefined values you'd like for different symbols. + + Returns: + A new tree per possible mutation + """ + if isinstance(root, Leaf): + raise ValueError(f"Can't mutate `Leaf`: {root}") + + variables = variables or {} + mutation_ids = {id(n) for n in which} + + def _inner(node: Node) -> Iterator[Node]: + match node: + case Leaf(): + # We can't mutate leafs as they don't have possible choices to + # choose from # by definition so we ignore it even if it's + # in the set of `mutation_ids` + yield node + case Passthrough(children=children) | Container(children=children): + rule = self.rules.get(node.symbol) + if not isinstance(rule, Grammar.NonTerminal): + raise ValueError( + "Expected a `NonTerminal` for symbol '{node.symbol}' from the" + f" grammar but got rule {rule}" + ) + + # If we've already determined the value of this shared symbol + if (existing := variables.get(node.symbol)) is not None: + yield existing + return + + # If mutate, we return all possible bfs values from that node. + if id(node) in mutation_ids: + yield from self.bfs( + node.symbol, + rng_shuffle=rng_shuffle, + max_depth=max_mutation_depth, + variables=variables, + ) + else: + children_itrs: list[Iterator[Node]] = [ + _inner(c) for c in children + ] + for new_children in itertools.product(*children_itrs): + new_node = node._replace(children=new_children) + if rule.shared: + variables[new_node.symbol] = new_node + yield new_node + case _: + assert_never(node) + + yield from _inner(root) + + def parse(self, s: str) -> Node: # noqa: C901, PLR0912, PLR0915 + """Parse a `str` into a string of the `Grammar`. + + !!! note + + The initial symbol does not necessarily need to match the + `start_symbol` of the grammar. + + Args: + s: the `str` to convert into a string of the `Grammar`. + + Returns: + The node that represents the string. 
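The round trip between `parse()` and `to_string()`, together with `bfs()` enumeration and `mutations()`, can be sketched as follows; the toy grammar, the depth bounds and the mutated symbol are illustrative only:

```python
from neps.space.grammar import Grammar

# Toy grammar: all ops are placeholder callables.
grammar = Grammar.from_dict("S", {
    "S": (["OP OP", "OP"], lambda *cs: cs),
    "OP": (["a", "b"], lambda *cs: cs),
    "a": lambda: "a", "b": lambda: "b",
})

# parse() and to_string() are inverses of one another.
node = grammar.parse("S(OP(a), OP(b))")
assert node.to_string() == "S(OP(a), OP(b))"

# Enumerate every producible string up to a bounded depth.
all_strings = [n.to_string() for n in grammar.bfs("S", max_depth=3)]

# Mutate every "OP" node, yielding one new tree per possibility.
mutants = [
    m.to_string()
    for m in grammar.mutations(
        node,
        which=node.select(how=("symbol", "OP")),
        max_mutation_depth=3,
    )
]
```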
+ """ + # Chunk up the str + string_tokens: list[str] = [] + brace_count = 0 + symbol = "" + for tok in s: + match tok: + case " ": + continue + case "(": + brace_count += 1 + if len(symbol) == 0: + raise ParseError( + f"Opening bracket '(' must be preceeded by symbol" + f" but was not.\n{s}" + ) + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ")": + brace_count -= 1 + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case ",": + if len(symbol) == 0: + string_tokens.append(tok) + continue + + string_tokens.append(symbol) + string_tokens.append(tok) + symbol = "" + case _: + symbol += tok + + if brace_count != 0: + raise ParseError( + f"Imbalanced braces, got {abs(brace_count)} too many" + f" {'(' if brace_count > 0 else ')'}." + ) + + if len(symbol) > 0: + string_tokens.append(symbol) + + # Convert to concrete tokens + tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = [] + for symbol in string_tokens: + if symbol in "(),": + tokens.append(symbol) # type: ignore + continue + + rule = self.rules.get(symbol) + match rule: + case Grammar.Terminal(): + tokens.append((symbol, self._leafs[symbol])) + case Grammar.NonTerminal(): + tokens.append((symbol, rule)) + case None: + raise ParseError( + f"Invalid symbol '{symbol}', must be either '(', ')', ',' or" + f" a symbol in {self.rules.keys()}" + ) + case _: + assert_never(rule) + + # If we're being strict that shared elements must be the same, then + # we can do so more cheaply at the beginning by just comparing subtokens + # before we parse. This will also takes care of subnesting of shared nodes + # and allow us to skip on some of the token stream as we encounter shared variable + shared_token_sizes: dict[str, int] = {} + _shared_locs: dict[str, list[int]] = {s: [] for s in self._shared} + + # We figure out the substrings of where each shared symbol begings and ends + if _shared_locs: + bracket_stack: list[int] = [] + bracket_pairs: dict[int, int] = {} + for i, tok in enumerate(tokens): + match tok: + case "," | (_, Leaf()): + continue + case ")": + start = bracket_stack.pop(-1) + bracket_pairs[start] = i + case "(": + bracket_stack.append(i) + case (symbol, Grammar.NonTerminal(shared=shared)): + if i + 1 >= len(tokens): + raise ParseError( + f"Symbol '{tok}' is a `NonTerminal`, implying that it " + " should contain some inner elements. However we found it" + f" at the last index of the {tokens=}" + ) + if tokens[i + 1] != "(": + raise ParseError( + f"Symbol '{tok}' at position {i} is a `NonTerminal`," + " implying that it should contain some inner elements." 
+ " However it was not followed by a '(' at position" + f" {i + 1} in {tokens=}" + ) + if shared is True: + _shared_locs[symbol].append(i) + case _: + assert_never(tok) + + # If we have more than one occurence of a shared symbol, + # we validate their subtokens match + for symbol, symbol_positions in _shared_locs.items(): + first_pos, rest = symbol_positions[0], symbol_positions[1:] + + # Calculate the inner tokens and length + bracket_first_start = first_pos + 1 + bracket_first_end = bracket_pairs[bracket_first_start] + + inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end] + shared_symbol_token_size = len(inner_tokens) + shared_token_sizes[symbol] = shared_symbol_token_size + + for symbol_start in rest: + # +2, skip symbol_start and skip opening bracket '(' + symbol_tokens = tokens[symbol_start + 2 : shared_symbol_token_size] + if symbol_tokens != inner_tokens: + raise ParseError( + f"Found mismatch in shared symbol '{symbol}'" + f" with {symbol=} starting at token `{symbol_start}`" + f" and the same symbol at token `{first_pos}` which has" + f" {inner_tokens=}.\n{tokens=}" + ) + + if len(tokens) == 0: + raise ParseError("Recieved an empty strng") + + match tokens[0]: + case (symbol, Leaf()): + if len(tokens) > 1: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `Terminal`, but was proceeded by more token." + f"\n{tokens=}" + ) + _, root = tokens[0] + case (symbol, Grammar.NonTerminal(op=op)): + if op is None: + raise ParseError( + f"First token was symbol '{symbol}' which is a `NonTerminal` that" + " is `passthrough`, i.e. it has no associated" + " operation and can not be the root." + ) + if len(tokens) < 4: + raise ParseError( + f"First token was symbol '{symbol}' which is" + f" a `NonTerminal`, but should have at least 3 more tokens" + " for a '(', 'child' and a closing ')'" + ) + + # NOTE: We don't care about shared here as we validate above that + # a shared variable can not contain itself, and there are no other + # symbols above or on the same level as this one (as it's the root). + # Hence we do not need to interact with `shared` here. + root = Container(symbol=symbol, children=[], op=op) + case "(" | ")" | ",": + raise ParseError("First token can not be a '(', ')' or a ','") + case rule: + assert_never(rule) + + if isinstance(root, Leaf): + return root + + variables: dict[str, Container | Passthrough] = {} + parent_stack: list[Container | Passthrough] = [] + current: Node = root + + token_stream = iter(tokens[1:]) + + for tok in token_stream: + match tok: + case ",": + parent_stack[-1].children.append(current) + case ")": + parent = parent_stack.pop() + parent.children.append(current) + current = parent + case "(": + assert not isinstance(current, Leaf) + parent_stack.append(current) + case (symbol, rule): + if isinstance(rule, Leaf): + current = rule + continue + + if rule.shared and (existing := variables.get(symbol)): + # Re-using a previous one so we can skip ahead in the tokens. 
+ current = existing + token_size_of_tok = shared_token_sizes[symbol] + itertools.islice(token_stream, token_size_of_tok) # Skips + continue + + if rule.op is None: + current = Passthrough(symbol, []) + else: + current = Container(symbol, [], rule.op) + + if rule.shared: + variables[symbol] = current + case _: + assert_never(tok) + + return current + + # TODO: The variables thing can mess up the max depth + def bfs( # noqa: C901 + self, + symbol: str, + *, + max_depth: int, + current_depth: int = 0, + variables: dict[str, Node] | None = None, + rng_shuffle: np.random.Generator | None = None, + ) -> Iterator[Node]: + """Iterate over all possible strings in a breadth first manner. + + Args: + symbol: The symbol to start the string from. + max_depth: The maximum depth of the produced string. This may not + be fully gauranteed given shared `NonTerminal`s. This is required + to prevent infinite recursion. Any non-terminated strings, i.e. those + which still require expansion, but have exceeded the depth, will not be + returned. + current_depth: What depth this call of the function is acting at. This is used + recursively and can mostly be left at `0`. + variables: Any instantiated shared variables used for a `shared=` + `NonTerminal`. + rng_shuffle: Whether to shuffle the order of the children when doing breadth + first search. This may only be required if you are not consuming the full + iterator this returns. For the most part this can be ignored. + + Returns: + An iterator over the valid strings in the grammar. + """ + if current_depth > max_depth: + return + + variables = variables or {} + shared_node = variables.get(symbol) + if shared_node is not None: + yield shared_node + return # TODO: check + + nxt_depth = current_depth + 1 + + rule = self.rules.get(symbol) + match rule: + case Grammar.Terminal(op=op): + node = Leaf(symbol, op) + yield node + case Grammar.NonTerminal(choices=choices, op=op): + for choice in choices: + children = choice.split(" ") + child_expansions: list[Iterator] = [ + self.bfs( + child_symbol, + max_depth=max_depth, + current_depth=nxt_depth, + rng_shuffle=rng_shuffle, + variables=variables, + ) + for child_symbol in children + ] + + if rng_shuffle: + # Works correctly with python lists, but typing for numpy is off + rng_shuffle.shuffle(child_expansions) # type: ignore + + for possible in itertools.product(*child_expansions): + if op is None: + node = Passthrough(symbol, children=list(possible)) + else: + node = Container(symbol, op=op, children=list(possible)) + + if rule.shared: + variables[symbol] = node + + yield node + case None: + raise ValueError(f"No symbol {symbol} in rules {self.rules.keys()}") + case _: + assert_never(rule) + + def is_valid( + self, + node: Node, + *, + already_shared: set[str] | None = None, + ) -> bool: + """Check if a given string is valid. + + Args: + node: The start of the string. + already_shared: Use for recursion, can mostly be kept as `None`. + Used to ensure that `NonTerminal`s that are `shared=True`, do + not contain themselves. + """ + rule = self.rules.get(node.symbol) + if rule is None: + raise ValueError( + f"Node has unknown symbol {node.symbol}, valid symbols are" + f" {self.rules.keys()}" + ) + + # We should never encounter a situtation where we have some nesting of shared + # nodes, for example, consider the following, where L2 is shared. + # L2 -> x -> ... -> L1 -> x -> ... 
+ already_shared = already_shared or set() + if ( + isinstance(rule, Grammar.NonTerminal) + and rule.shared + and node.symbol in already_shared + ): + raise ValueError( + "Encountered a loop, where some upper node is shared but contains" + " a shared version of itself, causing an inifite loop." + ) + + match node: + case Leaf(symbol): + return symbol in self.rules + case Container(symbol, children, _) | Passthrough(symbol, children): + s = " ".join(child.symbol for child in children) + + match rule: + case Grammar.Terminal(_): + return s in self.rules and all( + self.is_valid(child, already_shared=already_shared.copy()) + for child in children + ) + case Grammar.NonTerminal(choices, _): + return s in choices and all( + self.is_valid(child, already_shared=already_shared.copy()) + for child in children + ) + case _: + assert_never(rule) + return None + case _: + assert_never(node) + return None + + def to_model(self, string: str) -> Any: + """Convert a string form this grammar into its model form.""" + node = self.parse(string) + return node.to_model() + + +# TODO: This is just for plotting, not sure where it should go +# https://stackoverflow.com/a/29597210 +def hierarchy_pos( + G: nx.DiGraph, + root: int, + width: float = 2.0, + vert_gap: float = 1.2, + vert_loc: float = 1, + xcenter: float = 1.5, +) -> dict[int, tuple[float, float]]: + """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. + Licensed under Creative Commons Attribution-Share Alike. + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + """ + if not nx.is_tree(G): + raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") + + def _hierarchy_pos( + G, + root, + width=2.0, + vert_gap=1.2, + vert_loc: float = 1, + xcenter=1.5, + pos: dict[int, tuple[float, float]] | None = None, + parent=None, + ) -> dict[int, tuple[float, float]]: + """See hierarchy_pos docstring for most arguments. + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. 
- only affects it if non-directed + + """ + if pos is None: + pos = {root: (xcenter, vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children) != 1: + dx = width / len(children) + nextx = xcenter - width / 3 - dx / 2 + for child in children: + nextx += dx + pos = _hierarchy_pos( + G, + child, + width=dx, + vert_gap=vert_gap, + vert_loc=vert_loc - vert_gap, + xcenter=nextx, + pos=pos, + parent=root, + ) + return pos + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) diff --git a/neps/space/search_space.py b/neps/space/search_space.py index 2b0659f6a..166378421 100644 --- a/neps/space/search_space.py +++ b/neps/space/search_space.py @@ -9,7 +9,14 @@ from dataclasses import dataclass, field from typing import Any -from neps.space.parameters import Categorical, Constant, Float, Integer, Parameter +from neps.space.grammar import Grammar +from neps.space.parameters import ( + Categorical, + Constant, + Float, + Integer, + Parameter, +) # NOTE: The use of `Mapping` instead of `dict` is so that type-checkers @@ -19,12 +26,15 @@ class SearchSpace(Mapping[str, Parameter | Constant]): """A container for parameters.""" - elements: Mapping[str, Parameter | Constant] = field(default_factory=dict) + elements: Mapping[str, Parameter | Grammar | Constant] = field(default_factory=dict) """All items in the search space.""" categoricals: Mapping[str, Categorical] = field(init=False) """The categorical hyperparameters in the search space.""" + grammars: Mapping[str, Grammar] = field(init=False) + """The grammar parameters of the search space.""" + numerical: Mapping[str, Integer | Float] = field(init=False) """The numerical hyperparameters in the search space. @@ -43,14 +53,9 @@ class SearchSpace(Mapping[str, Parameter | Constant]): """The constants in the search space.""" @property - def searchables(self) -> Mapping[str, Parameter]: - """The hyperparameters that can be searched over. - - !!! note - - This does not include either constants or fidelities. - """ - return {**self.numerical, **self.categoricals} + def grammar(self) -> tuple[str, Grammar] | None: + """The grammar parameter for the search space if any.""" + return None if len(self.grammars) == 0 else next(iter(self.grammars.items())) @property def fidelity(self) -> tuple[str, Float | Integer] | None: @@ -65,6 +70,7 @@ def __post_init__(self) -> None: numerical: dict[str, Float | Integer] = {} categoricals: dict[str, Categorical] = {} constants: dict[str, Any] = {} + grammars: dict[str, Grammar] = {} # Process the hyperparameters for name, hp in self.elements.items(): @@ -86,7 +92,14 @@ def __post_init__(self) -> None: categoricals[name] = hp case Constant(): constants[name] = hp.value - + case Grammar(): + if len(grammars) >= 1: + raise ValueError( + "neps only supports one grammar parameter in the" + " pipeline space, but multiple were given." 
+ f" Grammars: {grammars}, new: {name}" + ) + grammars[name] = hp case _: raise ValueError(f"Unknown hyperparameter type: {hp}") @@ -94,8 +107,9 @@ def __post_init__(self) -> None: self.numerical = numerical self.constants = constants self.fidelities = fidelities + self.grammars = grammars - def __getitem__(self, key: str) -> Parameter | Constant: + def __getitem__(self, key: str) -> Parameter | Constant | Grammar: return self.elements[key] def __iter__(self) -> Iterator[str]: diff --git a/tests/test_search_space.py b/tests/test_search_space.py index 73073a0cc..560b0af02 100644 --- a/tests/test_search_space.py +++ b/tests/test_search_space.py @@ -2,7 +2,7 @@ import pytest -from neps import Categorical, Constant, Float, Integer, SearchSpace +from neps import Categorical, Constant, Float, Grammar, Integer, SearchSpace def test_search_space_orders_parameters_by_name(): @@ -19,6 +19,16 @@ def test_multipe_fidelities_raises_error(): ) +def test_mutliple_grammars_raises_error(): + with pytest.raises(ValueError, match="neps only supports one grammar parameter"): + SearchSpace( + { + "a": Grammar.from_dict("s", {"s": lambda _: None}), + "b": Grammar.from_dict("s", {"s": lambda _: None}), + } + ) + + def test_sorting_of_parameters_into_subsets(): elements = { "a": Float(0, 1), @@ -26,6 +36,7 @@ def test_sorting_of_parameters_into_subsets(): "c": Categorical(["a", "b", "c"]), "d": Float(0, 1, is_fidelity=True), "x": Constant("x"), + "g": Grammar.from_dict("s", {"s": lambda _: None}), } space = SearchSpace(elements) assert space.elements == elements @@ -33,10 +44,13 @@ def test_sorting_of_parameters_into_subsets(): assert space.numerical == {"a": elements["a"], "b": elements["b"]} assert space.fidelities == {"d": elements["d"]} assert space.constants == {"x": "x"} + assert space.grammars == {"g": elements["g"]} - assert space.searchables == { + parameters = {**space.numerical, **space.categoricals} + assert parameters == { "a": elements["a"], "b": elements["b"], "c": elements["c"], } assert space.fidelity == ("d", elements["d"]) + assert space.grammar == ("g", elements["g"]) diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index 57b6db946..7dde50da9 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -132,7 +132,7 @@ def optimizer_and_key_and_search_space( if key in JUST_SKIP: pytest.xfail(f"{key} is not instantiable") - if key in REQUIRES_PRIOR and search_space.searchables["a"].prior is None: + if key in REQUIRES_PRIOR and search_space.numerical["a"].prior is None: pytest.xfail(f"{key} requires a prior") if len(search_space.fidelities) > 0 and key in OPTIMIZER_FAILS_WITH_FIDELITY: From c69d8169a789b9a9ed0c014f7b2c81f027fcc30f Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 20 Feb 2025 17:10:49 +0100 Subject: [PATCH 43/50] style: cleanup root files --- graph.py | 1012 ------------------------------------------------- perf.py | 64 ---- t.py | 13 - test_graph.py | 434 --------------------- 4 files changed, 1523 deletions(-) delete mode 100644 graph.py delete mode 100644 perf.py delete mode 100644 t.py delete mode 100644 test_graph.py diff --git a/graph.py b/graph.py deleted file mode 100644 index aac96ea6c..000000000 --- a/graph.py +++ /dev/null @@ -1,1012 +0,0 @@ -from __future__ import annotations - -import itertools -from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator -from dataclasses import dataclass, field -from functools import partial -from typing import Any, 
ClassVar, Literal, NamedTuple, TypeAlias -from typing_extensions import assert_never - -import networkx as nx -import numpy as np -from torch import nn - -from neps.exceptions import NePSError - - -class ParseError(NePSError): - pass - - -# OPTIM: Calling `np.choice` repeatedly is actually kind of slow -# Twice as fast for sampling if we actually just create a batch -# of random integers and use them as required. -@dataclass -class BufferedRandIntStream: - rng: np.random.Generator - buffer_size: int = 50 - _cur_ix: int = 0 - - MAX_INT: ClassVar[int] = np.iinfo(np.int64).max - _nums: list[int] = field(default_factory=list) - - def next(self, n: int) -> int: - if self._cur_ix >= len(self._nums): - self._nums = self.rng.integers( - self.MAX_INT, size=self.buffer_size, dtype=np.int64 - ).tolist() - - self._cur_ix = 0 - - i = self._nums[self._cur_ix] % n - - self._cur_ix += 1 - return i - - -class ReLUConvBN(nn.Module): - def __init__(self, out_channels, kernel_size, stride, padding): - super().__init__() - - self.kernel_size = kernel_size - self.op = nn.Sequential( - nn.ReLU(inplace=False), - nn.LazyConv2d( - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=2, - bias=False, - ), - nn.LazyBatchNorm2d(affine=True, track_running_stats=True), - ) - - def forward(self, x): - return self.op(x) - - -class Identity(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x - - -def dfs_node(node: Node) -> Iterator[Node]: - stack: list[Node] = [node] - while stack: - nxt = stack.pop(-1) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - stack.extend(reversed(children)) - case _: - assert_never(nxt) - - -def bfs_node(node: Node) -> Iterator[Node]: - queue: list[Node] = [node] - while queue: - nxt = queue.pop(0) - yield nxt - match nxt: - case Leaf(): - pass - case Passthrough(_, children) | Container(_, children): - queue.extend(children) - case _: - assert_never(nxt) - - -class Leaf(NamedTuple): - symbol: str - op: Callable - - def __hash__(self) -> int: - return hash(self.symbol) - - -class Container(NamedTuple): - symbol: str - children: list[Node] - op: Callable - - def __hash__(self) -> int: - return hash(self.symbol) + hash(tuple(self.children)) - - -class Passthrough(NamedTuple): - symbol: str - children: list[Node] - - def __hash__(self) -> int: - return hash(self.symbol) + hash(tuple(self.children)) - - -Node: TypeAlias = Container | Passthrough | Leaf - - -@dataclass -class Grammar: - rules: dict[str, Terminal | NonTerminal] - _shared: dict[str, NonTerminal] = field(init=False) - _leafs: dict[str, Leaf] = field(init=False) - - class Terminal(NamedTuple): - op: Callable - - class NonTerminal(NamedTuple): - choices: list[str] - op: Callable | None = None - shared: bool = False - - def __post_init__(self) -> None: - self._shared = { - s: r - for s, r in self.rules.items() - if isinstance(r, Grammar.NonTerminal) and r.shared - } - self._leafs = { - s: Leaf(s, r.op) - for s, r in self.rules.items() - if isinstance(r, Grammar.Terminal) - } - - @classmethod - def from_dict( - cls, - grammar: dict[ - str, - Callable - | list[str] - | tuple[list[str], Callable] - | Grammar.Terminal - | Grammar.NonTerminal, - ], - ) -> Grammar: - rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} - for symbol, rule in grammar.items(): - match rule: - case Grammar.Terminal() | Grammar.NonTerminal(): - rules[symbol] = rule - case (choices, op) if isinstance(choices, list) 
and callable(op): - # > e.g. "S": (["A", "A B", "C"], op) - rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) - missing = rhs - grammar.keys() - if any(missing): - raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") - - rules[symbol] = Grammar.NonTerminal(choices, op, shared=False) - - case choices if isinstance(choices, list): - # > e.g. "S": ["A", "A B", "C"] - rhs = set(itertools.chain(*(choice.split(" ") for choice in choices))) - missing = rhs - grammar.keys() - if any(missing): - raise ValueError(f"Symbols {rhs} not in grammar {grammar.keys()}") - - rules[symbol] = Grammar.NonTerminal(choices, op=None, shared=False) - - case op if callable(op): - # > e.g. "S": op - rules[symbol] = Grammar.Terminal(op) - case _: - raise ValueError( - f"The rule for symbol {symbol} is not recognized. Should be" - " a list of of symbols, a callable or a tuple with both." - f"\n Got {rule}" - ) - - return Grammar(rules) - - -def sample_grammar( - symbol: str, - grammar: Grammar, - *, - rng: np.random.Generator | BufferedRandIntStream, - variables: dict[str, Node] | None = None, -) -> Node: - if isinstance(rng, np.random.Generator): - rng = BufferedRandIntStream(rng=rng) - - variables = variables or {} - rule = grammar.rules.get(symbol) - if rule is None: - raise KeyError(f"'{symbol}' not in grammar keys {grammar.rules.keys()}") - - stack: list[Container | Passthrough] = [] - match rule: - case Grammar.Terminal(): - return grammar._leafs[symbol] - case Grammar.NonTerminal(choices, op, shared): - shared_node = variables.get(symbol) - if shared_node is not None: - return shared_node - - i = rng.next(len(rule.choices)) - initial_sample = rule.choices[i] - children_symbols = initial_sample.split(" ") - root = Passthrough(symbol, []) if op is None else Container(symbol, [], op) - stack.append(root) - case _: - assert_never(rule) - - while stack: - parent = stack.pop() - i = rng.next(len(choices)) - choice = choices[i] - children_symbols = choice.split(" ") - - for child_symbol in children_symbols: - rule = grammar.rules[child_symbol] - match rule: - case Grammar.Terminal(): - parent.children.append(grammar._leafs[child_symbol]) - case Grammar.NonTerminal(choices, op, shared): - shared_node = variables.get(child_symbol) - if shared_node is not None: - parent.children.append(shared_node) - continue - - sub_parent = ( - Passthrough(child_symbol, []) - if op is None - else Container(child_symbol, [], op) - ) - parent.children.append(sub_parent) - stack.append(sub_parent) - - if shared: - variables[child_symbol] = sub_parent - case _: - assert_never(rule) - - return root - - -def to_node_from_graph(graph: nx.DiGraph, grammar: Grammar) -> Node: - # Find the unique root (a node with no incoming edges) - _root = next((n for n, d in graph.in_degree if d == 0), None) - if _root is None: - raise ValueError( - "Could not find a root in the given graph (a node with indegree 1)." 
- ) - - variables: dict[str, Node] = {} - - def _recurse(node_id: int) -> Node: - symbol = graph.nodes[node_id].get("label") - if symbol is None: - raise ValueError(f"Node {node_id} does not have a 'label' property.") - - rule = grammar.rules.get(symbol) - if rule is None: - raise ValueError( - f"Symbol '{symbol}' not found in grammar rules: {grammar.rules.keys()}" - ) - - # Based on the type of rule, construct the proper node - match rule: - case Grammar.Terminal(op=op): - node = Leaf(symbol, op) - case Grammar.NonTerminal(op=op): - if (shared_node := variables.get(symbol)) is not None: - return shared_node - - children = [_recurse(child_id) for child_id in graph.successors(node_id)] - node = ( - Passthrough(symbol, children) - if op is None - else Container(symbol, children, op) - ) - if rule.shared: - variables[symbol] = node - case _: - raise ValueError(f"Unexpected rule type for symbol '{symbol}': {rule}") - - return node - - # Start with the root node - return _recurse(_root) - - -def select( - root: Node, - *, - how: ( - tuple[Literal["symbol"], str] - | tuple[Literal["depth"], int | range] - | tuple[Literal["climb"], int | range] - ), -) -> Iterator[Node]: - match how: - case ("symbol", symbol): - for node in bfs_node(root): - if node.symbol == symbol: - yield node - case ("depth", depth): - if isinstance(depth, int): - depth = range(depth, depth + 1) - - queue_depth: list[tuple[Node, int]] = [(root, 0)] - while queue_depth: - nxt, d = queue_depth.pop(0) - if d in depth: - yield nxt - - if d >= depth.stop: - continue - - match nxt: - case Leaf(): - pass - case Passthrough(children=children) | Container(children=children): - queue_depth.extend([(child, d + 1) for child in children]) - case _: - assert_never(nxt) - - case ("climb", climb): - if isinstance(climb, int): - climb = range(climb, climb + 1) - - # First, we iterate downwards, populating parent paths back - # up. As the id for a Leaf is shared across all similar leafs - # as well as the fact shared nodes will share the same node id, - # we could have multiple parents per child id. - parents: defaultdict[int, list[Node]] = defaultdict(list) - - # We remove duplicates using a dict and the shared ids, a list would - # end up with duplicates for every leaf. We use this later to begin - # the climb iteration - leafs: dict[int, Node] = {} - - queue_climb: list[Node] = [root] - while queue_climb: - nxt = queue_climb.pop(0) - this_id = id(nxt) - match nxt: - case Leaf(): - leafs[this_id] = nxt - case Passthrough(children=children) | Container(children=children): - for child in children: - parents[id(child)].append(nxt) - queue_climb.extend(children) - case _: - assert_never(nxt) - - # Now we work backwards from the leafs for each of the possible parents - # for the node id, yielding if we're within the climb path. If we've gone - # pass the climb value, we can stop iterating there. 
- climb_queue: list[tuple[Node, int]] = [] - climb_queue.extend([(leaf, 0) for leaf in leafs.values()]) - seen: set[int] = set() - while climb_queue: - node, climb_value = climb_queue.pop(0) - node_id = id(node) - if node_id in seen: - continue - - if climb_value in climb: - seen.add(node_id) - yield node - - if climb_value < climb.stop: - possible_node_parents = parents[id(node)] - climb_queue.extend( - [ - (p, climb_value + 1) - for p in possible_node_parents - if id(p) not in seen - ] - ) - - case _: - assert_never(how) - - -def mutations( - root: Node, - grammar: Grammar, - *, - which: Iterable[Node], - max_mutation_depth: int, - rng_shuffle: np.random.Generator | None = None, - variables: dict[str, Node] | None = None, -) -> Iterator[Node]: - """Mutate nodes, returning all the different possibilities for them. - - Args: - root: The root from which to operate. - grammar: The grammar which holds the rules used for mutation. - which: What nodes to mutate, look at `select()`. - max_mutation_depth: The maximum depth allowed for bfs iteration - on the mutant nodes. - rng_shuffle: Whether to shuffle the return order. This takes place at the place - when considering the possibilities for a given node, and does not follow - the order of `NonTerminal.choices`. - variables: Any predefined values you'd like for different symbols. - - Returns: - A new tree per possible mutation - """ - if isinstance(root, Leaf): - raise ValueError(f"Can't mutate `Leaf`: {root}") - - variables = variables or {} - mutation_ids = {id(n) for n in which} - - def _inner(node: Node) -> Iterator[Node]: - match node: - case Leaf(): - # We can't mutate leafs as they don't have possible choices to choose from - # by definition so we ignore it even if it's in the set of `mutation_ids` - yield node - case Passthrough(children=children) | Container(children=children): - rule = grammar.rules.get(node.symbol) - if not isinstance(rule, Grammar.NonTerminal): - raise ValueError( - "Expected a `NonTerminal` for symbol '{node.symbol}' from the" - f" grammar but got rule {rule}" - ) - - # If we've already determined the value of this shared symbol - if (existing := variables.get(node.symbol)) is not None: - yield existing - return - - # If mutate, we return all possible bfs values from that node. 
- if id(node) in mutation_ids: - yield from bfs_grammar( - grammar, - node.symbol, - rng_shuffle=rng_shuffle, - max_depth=max_mutation_depth, - variables=variables, - ) - else: - children_itrs: list[Iterator[Node]] = [_inner(c) for c in children] - for new_children in itertools.product(*children_itrs): - new_node = node._replace(children=new_children) - if rule.shared: - variables[new_node.symbol] = new_node - yield new_node - case _: - assert_never(node) - - yield from _inner(root) - - -def to_nxgraph(root: Node, *, include_passthroughs: bool = False) -> nx.DiGraph: - nodes: list[tuple[int, dict]] = [] - edges: list[tuple[int, int]] = [] - id_generator: Iterator[int] = itertools.count() - - def _recurse_fill_lists(node: Node, *, parent_id: int) -> None: - node_id = next(id_generator) - match node: - # Atoms are just a node with an edge to its parent - case Leaf(symbol): - nodes.append((node_id, {"label": symbol})) - edges.append((parent_id, node_id)) - - # If we have a passthrough and shouldn't include them, we simply - # forward on the `parent_id` we recieved to the children - case Passthrough(_, children) if include_passthroughs is False: - for child in children: - _recurse_fill_lists(child, parent_id=parent_id) - - # Containers are a node in the graph, with edges to its - # children (direct, or through passthrough) - case Container(symbol, children, _) | Passthrough(symbol, children): - nodes.append((node_id, {"label": symbol})) - edges.append((parent_id, node_id)) - - for child in children: - _recurse_fill_lists(child, parent_id=node_id) - - case _: - assert_never(root.kind) - - graph = nx.DiGraph() - root_id = next(id_generator) - nodes.append((root_id, {"label": root.symbol})) - match root: - case Leaf(): - pass - case Passthrough(_, children) if include_passthroughs is False: - raise ValueError( - f"Can't create a graph starting from a `Passthrough` {root.symbol}, " - " unless `include_passthrough`" - ) - case Container(_, children, _) | Passthrough(_, children): - for child in children: - _recurse_fill_lists(child, parent_id=root_id) - case _: - assert_never(root) - - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -def parse(grammar: Grammar, string: str) -> Node: - # Chunk up the str - string_tokens: list[str] = [] - brace_count = 0 - symbol = "" - for tok in string: - match tok: - case " ": - continue - case "(": - brace_count += 1 - if len(symbol) == 0: - raise ParseError( - f"Opening bracket '(' must be preceeded by symbol" - f" but was not.\n{string}" - ) - - string_tokens.append(symbol) - string_tokens.append(tok) - symbol = "" - case ")": - brace_count -= 1 - if len(symbol) == 0: - string_tokens.append(tok) - continue - - string_tokens.append(symbol) - string_tokens.append(tok) - symbol = "" - case ",": - if len(symbol) == 0: - string_tokens.append(tok) - continue - - string_tokens.append(symbol) - string_tokens.append(tok) - symbol = "" - case _: - symbol += tok - - if brace_count != 0: - raise ParseError( - f"Imbalanced braces, got {abs(brace_count)} too many" - f" {'(' if brace_count > 0 else ')'}." 
- ) - - if len(symbol) > 0: - string_tokens.append(symbol) - - # Convert to concrete tokens - tokens: list[Literal[")", "(", ","] | tuple[str, Leaf | Grammar.NonTerminal]] = [] - for symbol in string_tokens: - if symbol in "(),": - tokens.append(symbol) # type: ignore - continue - - rule = grammar.rules.get(symbol) - match rule: - case Grammar.Terminal(): - tokens.append((symbol, grammar._leafs[symbol])) - case Grammar.NonTerminal(): - tokens.append((symbol, rule)) - case None: - raise ParseError( - f"Invalid symbol '{symbol}', must be either '(', ')', ',' or" - f" a symbol in {grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - # If we're being strict that shared elements must be the same, then - # we can do so more cheaply at the beginning by just comparing subtokens - # before we parse. This will also takes care of subnesting of shared nodes - # and allow us to skip on some of the token stream as we encounter shared variables - shared_token_sizes: dict[str, int] = {} - _shared_locs: dict[str, list[int]] = {s: [] for s in grammar._shared} - - # We figure out the substrings of where each shared symbol begings and ends - if _shared_locs: - bracket_stack: list[int] = [] - bracket_pairs: dict[int, int] = {} - for i, tok in enumerate(tokens): - match tok: - case "," | (_, Leaf()): - continue - case ")": - start = bracket_stack.pop(-1) - bracket_pairs[start] = i - case "(": - bracket_stack.append(i) - case (symbol, Grammar.NonTerminal(shared=shared)): - if i + 1 >= len(tokens): - raise ParseError( - f"Symbol '{tok}' is a `NonTerminal`, implying that it should" - " contain some inner elements. However we found it at" - f" the last index of the {tokens=}" - ) - if tokens[i + 1] != "(": - raise ParseError( - f"Symbol '{tok}' at position {i} is a `NonTerminal`," - " implying that it should contain some inner elements." - f" However it was not followed by a '(' at position {i + 1}" - f" in {tokens=}" - ) - if shared is True: - _shared_locs[symbol].append(i) - case _: - assert_never(tok) - - # If we have more than one occurence of a shared symbol, - # we validate their subtokens match - for symbol, symbol_positions in _shared_locs.items(): - first_pos, rest = symbol_positions[0], symbol_positions[1:] - - # Calculate the inner tokens and length - bracket_first_start = first_pos + 1 - bracket_first_end = bracket_pairs[bracket_first_start] - - inner_tokens = tokens[bracket_first_start + 1 : bracket_first_end] - shared_symbol_token_size = len(inner_tokens) - shared_token_sizes[symbol] = shared_symbol_token_size - - for symbol_start in rest: - # +2, skip symbol_start and skip opening bracket '(' - symbol_tokens = tokens[symbol_start + 2 : shared_symbol_token_size] - if symbol_tokens != inner_tokens: - raise ParseError( - f"Found mismatch in shared symbol '{symbol}'" - f" with {symbol=} starting at token `{symbol_start}`" - f" and the same symbol at token `{first_pos}` which has" - f" {inner_tokens=}.\n{tokens=}" - ) - - if len(tokens) == 0: - raise ParseError("Recieved an empty strng") - - match tokens[0]: - case (symbol, Leaf()): - if len(tokens) > 1: - raise ParseError( - f"First token was symbol '{symbol}' which is" - f" a `Terminal`, but was proceeded by more token." - f"\n{tokens=}" - ) - _, root = tokens[0] - case (symbol, Grammar.NonTerminal(op=op)): - if op is None: - raise ParseError( - f"First token was symbol '{symbol}' which is" - f" a `NonTerminal` that is `passthrough`, i.e. it has no associated" - " operation and can not be the root." 
- ) - if len(tokens) < 4: - raise ParseError( - f"First token was symbol '{symbol}' which is" - f" a `NonTerminal`, but should have at least 3 more tokens" - " for a '(', 'child' and a closing ')'" - ) - - # NOTE: We don't care about shared here as we validate above that - # a shared variable can not contain itself, and there are no other - # symbols above or on the same level as this one (as it's the root). - # Hence we do not need to interact with `shared` here. - root = Container(symbol=symbol, children=[], op=op) - case "(" | ")" | ",": - raise ParseError("First token can not be a '(', ')' or a ','") - case rule: - assert_never(rule) - - if isinstance(root, Leaf): - return root - - variables: dict[str, Container | Passthrough] = {} - parent_stack: list[Container | Passthrough] = [] - current: Node = root - - token_stream = iter(tokens[1:]) - - for tok in token_stream: - match tok: - case ",": - parent_stack[-1].children.append(current) - case ")": - parent = parent_stack.pop() - parent.children.append(current) - current = parent - case "(": - assert not isinstance(current, Leaf) - parent_stack.append(current) - case (symbol, rule): - if isinstance(rule, Leaf): - current = rule - continue - - if rule.shared and (existing := variables.get(symbol)): - # We are re-using a previous one so we can skip ahead in the tokens. - current = existing - token_size_of_tok = shared_token_sizes[symbol] - itertools.islice(token_stream, token_size_of_tok) # Skips - continue - - if rule.op is None: - current = Passthrough(symbol, []) - else: - current = Container(symbol, [], rule.op) - - if rule.shared: - variables[symbol] = current - case _: - assert_never(tok) - - return current - - -# NOTE: Not sure we want this as a standalone function, but it serves to show some logic -def is_valid( - grammar: Grammar, - node: Node, - *, - already_shared: set[str] | None = None, -) -> bool: - rule = grammar.rules.get(node.symbol) - if rule is None: - raise ValueError( - f"Node has unknown symbol {node.symbol}, valid symbols are" - f" {grammar.rules.keys()}" - ) - - # We should never encounter a situtation where we have some nesting of shared nodes, - # for example, consider the following, where L2 is shared. - # L2 -> x -> ... -> L1 -> x -> ... - already_shared = already_shared or set() - if ( - isinstance(rule, Grammar.NonTerminal) - and rule.shared - and node.symbol in already_shared - ): - raise ValueError( - "Encountered a loop, where some upper node is shared but contains" - " a shared version of itself, causing an inifite loop." - ) - - match node: - case Leaf(symbol): - return symbol in grammar.rules - case Container(symbol, children, _) | Passthrough(symbol, children): - s = " ".join(child.symbol for child in children) - - match rule: - case Grammar.Terminal(_): - return s in grammar.rules and all( - is_valid(grammar, child, already_shared=already_shared.copy()) - for child in children - ) - case Grammar.NonTerminal(choices, _): - return s in choices and all( - is_valid(grammar, child, already_shared=already_shared.copy()) - for child in children - ) - case _: - assert_never(rule) - case _: - assert_never(node) - - -# TODO: Optimization, we don't need to recompute shared substrings. 
-# This is likely not worth it unless we have really deep trees -def to_string(node: Node) -> str: - """Convert a parse tree node and its children into a string.""" - match node: - case Leaf(symbol): - return symbol - case Passthrough(symbol, children) | Container(symbol, children): - return f"{symbol}({', '.join(to_string(c) for c in children)})" - case _: - assert_never(node) - - -# TODO: The variables thing can mess up the max depth -def bfs_grammar( # noqa: C901, D103 - grammar: Grammar, - symbol: str, - *, - max_depth: int, - current_depth: int = 0, - variables: dict[str, Node] | None = None, - rng_shuffle: np.random.Generator | None = None, -) -> Iterator[Node]: - if current_depth > max_depth: - return - - variables = variables or {} - shared_node = variables.get(symbol) - if shared_node is not None: - yield shared_node - return # TODO: check - - nxt_depth = current_depth + 1 - - rule = grammar.rules.get(symbol) - match rule: - case Grammar.Terminal(op=op): - node = Leaf(symbol, op) - yield node - case Grammar.NonTerminal(choices=choices, op=op): - for choice in choices: - children = choice.split(" ") - child_expansions: list[Iterator] = [ - bfs_grammar( - grammar, - child_symbol, - max_depth=max_depth, - current_depth=nxt_depth, - rng_shuffle=rng_shuffle, - variables=variables, - ) - for child_symbol in children - ] - - if rng_shuffle: - # This works correctly with python lists, but typing for numpy is off - rng_shuffle.shuffle(child_expansions) # type: ignore - - for possible in itertools.product(*child_expansions): - if op is None: - node = Passthrough(symbol, children=list(possible)) - else: - node = Container(symbol, op=op, children=list(possible)) - - if rule.shared: - variables[symbol] = node - - yield node - case None: - raise ValueError( - f"Could not find symbol {symbol} in table with keys{grammar.rules.keys()}" - ) - case _: - assert_never(rule) - - -def to_model(node: Node) -> Any: - """Convert a parse tree node and its children into some object it represents.""" - - def _build(_n: Node) -> list[Any] | Any: - match _n: - case Leaf(_, op): - return op() - case Container(_, children, op): - # The problem is that each child could be either: - # * A single 'thing', in the case of Leaf or Container - # * Multiple things, in case it's a passthrough - # Hence we flatten them out into a single big children itr - _l = [] - for child in children: - _b = _build(child) - if isinstance(_b, list): - _l.extend(_b) - continue - _l.append(_b) - - return op(*_l) - case Passthrough(_, children): - return [_build(child) for child in children] - case _: - assert_never(node) - - match node: - case Leaf() | Container(): - obj = _build(node) - assert not isinstance(obj, list) - return obj - case Passthrough(symbol): - raise ValueError(f"Can not call build on a `Passthrough` {symbol}") - case _: - assert_never(node) - - -grammar = { - "S": ( - ["C", "reluconvbn", "S", "S C", "O O O"], - nn.Sequential, - ), - "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), - "O": ["4", "1", "id"], - "reluconvbn": partial( - ReLUConvBN, in_channels=4, out_channels=3, kernel_size=3, stride=1, padding=1 - ), - "id": Identity, - "4": partial( - nn.Conv3d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 - ), - "2": partial( - nn.Conv3d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 - ), -} - - -# https://stackoverflow.com/a/29597210 -def hierarchy_pos( - G: nx.DiGraph, - root: int, - width: float = 2.0, - vert_gap: float = 1.2, - vert_loc: float = 1, - xcenter: float = 
1.5, -) -> dict[int, tuple[float, float]]: - """From Joel's answer at https://stackoverflow.com/a/29597210/2966723. - Licensed under Creative Commons Attribution-Share Alike. - - If the graph is a tree this will return the positions to plot this in a - hierarchical layout. - - G: the graph (must be a tree) - - root: the root node of current branch - - if the tree is directed and this is not given, - the root will be found and used - - if the tree is directed and this is given, then - the positions will be just for the descendants of this node. - - if the tree is undirected and not given, - then a random choice will be used. - - width: horizontal space allocated for this branch - avoids overlap with other branches - - vert_gap: gap between levels of hierarchy - - vert_loc: vertical location of root - - xcenter: horizontal location of root - """ - if not nx.is_tree(G): - raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") - - def _hierarchy_pos( - G, - root, - width=2.0, - vert_gap=1.2, - vert_loc: float = 1, - xcenter=1.5, - pos: dict[int, tuple[float, float]] | None = None, - parent=None, - ) -> dict[int, tuple[float, float]]: - """See hierarchy_pos docstring for most arguments. - - pos: a dict saying where all nodes go if they have been assigned - parent: parent of this branch. - only affects it if non-directed - - """ - if pos is None: - pos = {root: (xcenter, vert_loc)} - else: - pos[root] = (xcenter, vert_loc) - children = list(G.neighbors(root)) - if not isinstance(G, nx.DiGraph) and parent is not None: - children.remove(parent) - if len(children) != 1: - dx = width / len(children) - nextx = xcenter - width / 3 - dx / 2 - for child in children: - nextx += dx - pos = _hierarchy_pos( - G, - child, - width=dx, - vert_gap=vert_gap, - vert_loc=vert_loc - vert_gap, - xcenter=nextx, - pos=pos, - parent=root, - ) - return pos - - return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) diff --git a/perf.py b/perf.py deleted file mode 100644 index 28f7716e2..000000000 --- a/perf.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from functools import partial - -import numpy as np -from graph import ( - Grammar, - Identity, - Node, - ReLUConvBN, - parse, - sample_grammar, - to_model, - to_nxgraph, - to_string, -) -from torch import nn - -structure = { - "S": ( - Grammar.NonTerminal( - ["C", "reluconvbn", "S", "S C", "O O O", "S S O O O O O O"], - nn.Sequential, - ) - ), - "C": (["O", "O S reluconvbn", "O S", "S"], nn.Sequential), - "O": ["3", "1", "id"], - "reluconvbn": partial( - ReLUConvBN, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 - ), - "id": Identity, - "3": partial( - nn.Conv2d, in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 - ), - "1": partial( - nn.Conv2d, in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 - ), -} - -if __name__ == "__main__": - import time - - grammar = Grammar.from_dict(structure) - rng = np.random.default_rng() - sample: Node = sample_grammar("S", grammar=grammar, rng=rng) - graph = to_nxgraph(sample) - model = to_model(sample) - - t0 = time.perf_counter() - samples = 10_000 - - for _ in range(samples): - sample: Node = sample_grammar("S", grammar=grammar, rng=rng) - string = to_string(sample) - node = parse(string=string, grammar=grammar) - # graph = to_nxgraph(sample) - # mutate_leaf_parents(root=sample, grammar=grammar, rng=rng) - # model = to_model(sample) - - t1 = time.perf_counter() - import rich - - rich.print(f"sampling takes {(t1 - t0) / samples}s 
on average over {samples} samples") - rich.print(f"duration for {samples} samples: {t1 - t0}s ") diff --git a/t.py b/t.py deleted file mode 100644 index 5fbd25dbb..000000000 --- a/t.py +++ /dev/null @@ -1,13 +0,0 @@ - -import rich -import neps - -space = neps.SearchSpace( - { - "a": neps.Integer(0, 10), - "b": neps.Categorical(["a", "b", "c"]), - "c": neps.Float(1e-5, 1e0, log=True, prior=1e-3), - } - ) - -rich.print(space) diff --git a/test_graph.py b/test_graph.py deleted file mode 100644 index 09759a51b..000000000 --- a/test_graph.py +++ /dev/null @@ -1,434 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - -import pytest -from graph import ( - Container, - Grammar, - Leaf, - Node, - ParseError, - Passthrough, - bfs_node, - dfs_node, - parse, - select, - to_model, - to_node_from_graph, - to_nxgraph, - to_string, -) - - -# Leafs -@dataclass -class T: - s: str - - # This is the `op()` - def __call__(self) -> str: - return self.s - - -def join(*s: str) -> str: - return "[" + "".join(s) + "]" - - -grammar_1 = Grammar.from_dict( - { - "s": (["a", "b", "p", "p p"], join), - "p": ["a b", "s"], - "a": T("a"), - "b": T("b"), - } -) - -grammar_2 = Grammar.from_dict( - { - "L1": (["L2 L2 L3"], join), - "L2": Grammar.NonTerminal(["L3"], join, shared=True), - "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), - "a": T("a"), - "b": T("a"), - } -) - - -@pytest.mark.parametrize( - ("grammar", "string", "built", "node"), - [ - (grammar_1, "a", "a", Leaf("a", T("a"))), - (grammar_1, "b", "b", Leaf("b", T("b"))), - ( - grammar_1, - "s(a)", - "[a]", - Container("s", op=join, children=[Leaf("a", T("a"))]), - ), - ( - grammar_1, - "s(p(a, b))", - "[ab]", - Container( - "s", - children=[ - Passthrough( - "p", - children=[Leaf("a", T("a")), Leaf("b", T("b"))], - ), - ], - op=join, - ), - ), - ( - grammar_1, - "s(p(a, b), p(s(a)))", - "[ab[a]]", - Container( - "s", - children=[ - Passthrough( - "p", - children=[Leaf("a", T("a")), Leaf("b", T("b"))], - ), - Passthrough( - "p", - children=[Container("s", children=[Leaf("a", T("a"))], op=join)], - ), - ], - op=join, - ), - ), - ( - grammar_1, - "s(p(s(a)))", - "[[a]]", - Container( - "s", - children=[ - Passthrough( - "p", - children=[ - Container( - "s", - children=[Leaf("a", T("a"))], - op=join, - ) - ], - ), - ], - op=join, - ), - ), - ], -) -def test_string_serialization_and_deserialization_correct( - grammar: Grammar, - string: str, - built: str, - node: Node, -) -> None: - # Test parsing - parsed = parse(grammar, string) - assert parsed == node - - # Test serialization - serialized_again = to_string(parsed) - assert serialized_again == string - - # Test building - assert to_model(parsed) == built - - # Test graph and back again - graph = to_nxgraph(parsed, include_passthroughs=True) - node_again = to_node_from_graph(graph, grammar) - assert parsed == node_again - - -@pytest.mark.parametrize( - ("grammar", "string"), - [ - (grammar_1, "c"), - (grammar_1, ""), - (grammar_1, "s(a"), - (grammar_1, "p(a, b)"), - (grammar_1, "("), - (grammar_1, "s(a))"), - (grammar_1, "s((a)"), - (grammar_1, "s("), - (grammar_1, "s)"), - (grammar_1, "a, a"), - (grammar_1, "a,"), - (grammar_1, "s, s"), - # Invalid due to shared rule but not sharing values - (grammar_2, "L1(L2(L3(a)), L2(L3(a)), L3(b))"), - ], -) -def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: - with pytest.raises(ParseError): - parse(grammar, string) - - -def test_dfs_node_container() -> None: - node = Container( - "s", - children=[ - 
Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - ], - op=join, - ) - outcome = list(dfs_node(node)) - expected = [ - # First - Container( - "s", - children=[ - Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - ], - op=join, - ), - # go down left depth first - Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Leaf("a_left", T("a")), - Leaf("b_left", T("b")), - # go down right depth first - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - Leaf("a_right", T("a")), - Leaf("b_right", T("b")), - ] - for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): - assert e == o, f"Failed at index {i}" - - -def test_bfs_node_container() -> None: - node = Container( - "s", - children=[ - Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - ], - op=join, - ) - outcome = list(bfs_node(node)) - expected = [ - # First - Container( - "s", - children=[ - Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - ], - op=join, - ), - # Second level first - Container( - "s_left", - children=[Leaf("a_left", T("a")), Leaf("b_left", T("b"))], - op=join, - ), - Container( - "s_right", - children=[Leaf("a_right", T("a")), Leaf("b_right", T("b"))], - op=join, - ), - # Then 3rd level - Leaf("a_left", T("a")), - Leaf("b_left", T("b")), - Leaf("a_right", T("a")), - Leaf("b_right", T("b")), - ] - for i, (e, o) in enumerate(zip(expected, outcome, strict=True)): - assert e == o, f"Failed at index {i}" - - -def test_select_symbol() -> None: - root = Container( - "a", - children=[ - Container( - "b", - children=[ - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - ], - op=join, - ), - Container("c", children=[Leaf("l2", op=T("l2"))], op=join), - Leaf("l3", op=T("l3")), - Container( - "d", - children=[Leaf("l4", op=T("l4"))], - op=join, - ), - ], - op=join, - ) - selected = list(select(root, how=("symbol", "d"))) - assert selected == [ - Container( - "d", - children=[Leaf("l4", op=T("l4"))], - op=join, - ), - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - ] - - -def test_select_depth() -> None: - root = Container( - "a", - children=[ - Container( - "b", - children=[ - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - ], - op=join, - ), - Container("c", children=[Leaf("l2", op=T("l2"))], op=join), - Leaf("l3", op=T("l3")), - Container( - "d", - children=[Leaf("l4", op=T("l4"))], - op=join, - ), - ], - op=join, - ) - selected = list(select(root, how=("depth", 1))) - assert selected == root.children - - selected = list(select(root, how=("depth", range(1, 3)))) - expected = [ - # Depth 1 - *root.children, - # Depth 2 - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - Leaf("l2", op=T("l2")), - Leaf("l4", op=T("l4")), - ] - assert selected == expected - - -def test_select_climb() -> None: - # NOTE: The order is rather 
arbitrary and not much thought has been given to it. - # However the test still tests a particular order that was done by trial and - # error. Feel free to redo the order if this changes. - root = Container( - "a", - children=[ - Container( - "b", - children=[ - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - ], - op=join, - ), - Container("c", children=[Leaf("l2", op=T("l2"))], op=join), - Leaf("l3", op=T("l3")), - Container( - "d", - children=[Leaf("l4", op=T("l4"))], - op=join, - ), - ], - op=join, - ) - selected = list(select(root, how=("climb", 0))) - assert selected == [ - Leaf("l3", op=T("l3")), - Leaf("l2", op=T("l2")), - Leaf("l4", op=T("l4")), - Leaf("l1", op=T("l1")), - ] - - selected = list(select(root, how=("climb", range(1, 3)))) - expected = [ - root, - Container("c", children=[Leaf("l2", op=T("l2"))], op=join), - Container( - "d", - children=[Leaf("l4", op=T("l4"))], - op=join, - ), - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - Container( - "b", - children=[ - Container( - "d", - children=[Leaf("l1", op=T("l1"))], - op=join, - ), - ], - op=join, - ), - ] - for i, (sel, exp) in enumerate(zip(selected, expected, strict=True)): - assert sel == exp, f"Mismatch at pos {i}:\nExpected: {exp}\n\nGot: {sel}" From 4479b00e3773dfbb4ea0113392114785156833ef Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 20 Feb 2025 17:42:04 +0100 Subject: [PATCH 44/50] fix(grammar): to_model handles deep nested passthroughs --- neps/space/grammar.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/neps/space/grammar.py b/neps/space/grammar.py index d3064847f..0715b7a7a 100644 --- a/neps/space/grammar.py +++ b/neps/space/grammar.py @@ -95,6 +95,7 @@ from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never +import more_itertools import networkx as nx import numpy as np @@ -230,15 +231,11 @@ def _build(_n: Node) -> list[Any] | Any: # * A single 'thing', in the case of Leaf or Container # * Multiple things, in case it's a passthrough # Hence we flatten them out into a single big children itr - _l = [] - for child in children: - _b = _build(child) - if isinstance(_b, list): - _l.extend(_b) - continue - _l.append(_b) - - return op(*_l) + built_children = more_itertools.collapse( + (_build(child) for child in children), + base_type=(op if isinstance(op, type) else None), + ) + return op(*built_children) case Passthrough(_, children): return [_build(child) for child in children] case _: From e96e1c4f65a18ecefff1f8d34a79bc956b4660d7 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 20 Feb 2025 17:50:02 +0100 Subject: [PATCH 45/50] style: typing fixes --- neps/optimizers/priorband.py | 9 +++++++-- neps/optimizers/random_search.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/neps/optimizers/priorband.py b/neps/optimizers/priorband.py index 9d6d23e4b..eacd46181 100644 --- a/neps/optimizers/priorband.py +++ b/neps/optimizers/priorband.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, assert_never import numpy as np import torch @@ -103,13 +103,18 @@ def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]: or spent_one_sh_bracket_worth_of_fidelity is False or any_rung_with_eta_evals is False ): - policy = np.random.choice(["prior", "random"], p=[w_prior, w_random]) + policy: Literal["prior", "random"] = 
np.random.choice( + ["prior", "random"], + p=[w_prior, w_random], + ) match policy: case "prior": config = prior_dist.sample_config(to=self.encoder) case "random": _sampler = Sampler.uniform(ndim=self.encoder.ndim) config = _sampler.sample_config(to=self.encoder) + case _: + assert_never(policy) return config diff --git a/neps/optimizers/random_search.py b/neps/optimizers/random_search.py index fab67e546..44d11bf90 100644 --- a/neps/optimizers/random_search.py +++ b/neps/optimizers/random_search.py @@ -30,7 +30,7 @@ def __call__( ) -> SampledConfig | list[SampledConfig]: n_trials = len(trials) _n = 1 if n is None else n - configs_tensor: Tensor = self.numerical_sampler.sample(_n, to=self.encoder) + configs_tensor = self.numerical_sampler.sample(_n, to=self.encoder) config_dicts = self.encoder.decode(configs_tensor) for config in config_dicts: From c57398fab2518cb282b23ac9ee6545d287bdb452 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 20 Feb 2025 17:54:03 +0100 Subject: [PATCH 46/50] fix: raise explicit grammar not supported for optimizers --- neps/optimizers/algorithms.py | 3 +++ neps/optimizers/bracket_optimizer.py | 6 ++++++ neps/optimizers/ifbo.py | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/neps/optimizers/algorithms.py b/neps/optimizers/algorithms.py index dafad4f75..6be6ef2b2 100644 --- a/neps/optimizers/algorithms.py +++ b/neps/optimizers/algorithms.py @@ -384,6 +384,9 @@ def grid_search(pipeline_space: SearchSpace) -> GridSearch: """ from neps.optimizers.utils.grid import make_grid + if pipeline_space.grammar is not None: + raise NotImplementedError("Grammars not supported for `grid_search` yet.") + return GridSearch(configs_list=make_grid(pipeline_space)) diff --git a/neps/optimizers/bracket_optimizer.py b/neps/optimizers/bracket_optimizer.py index cea6624d9..5279e0b68 100644 --- a/neps/optimizers/bracket_optimizer.py +++ b/neps/optimizers/bracket_optimizer.py @@ -249,6 +249,12 @@ class BracketOptimizer: fid_name: str """The name of the fidelity in the space.""" + def __post_init__(self) -> None: + if self.space.grammar is not None: + raise NotImplementedError( + "Grammars not supported for `BracketOptimizer` yet." + ) + def __call__( # noqa: C901, PLR0912 self, trials: Mapping[str, Trial], diff --git a/neps/optimizers/ifbo.py b/neps/optimizers/ifbo.py index 58671bcfd..416cd66ab 100755 --- a/neps/optimizers/ifbo.py +++ b/neps/optimizers/ifbo.py @@ -129,6 +129,10 @@ class IFBO: Each one will be treated as an individual fidelity level. 
""" + def __post_init__(self) -> None: + if self.space.grammar is not None: + raise NotImplementedError("Grammars not supported for `IFBO` yet.") + def __call__( self, trials: Mapping[str, Trial], From 7f5165d04d73fbee6a175d7f8aa8770f10f9c8b9 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 18:04:18 +0100 Subject: [PATCH 47/50] fix: graph tests --- tests/test_graph.py | 74 ++++++++++++--------------------------------- 1 file changed, 19 insertions(+), 55 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index d877ed47d..927ef7ee2 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -8,7 +8,9 @@ import numpy as np import pytest import torch -from graph import ( +from torch import nn + +from neps.space.grammar import ( Container, Grammar, Leaf, @@ -17,16 +19,11 @@ Passthrough, bfs_node, dfs_node, - mutations, - parse, - sample_grammar, select, to_model, - to_node_from_graph, to_nxgraph, to_string, ) -from torch import nn # Leafs @@ -44,26 +41,29 @@ def join(*s: str) -> str: grammar_1 = Grammar.from_dict( - { + start_symbol="s", + grammar={ "s": (["a", "b", "p", "p p"], join), "p": ["a b", "s"], "a": T("a"), "b": T("b"), - } + }, ) grammar_2 = Grammar.from_dict( - { + start_symbol="L1", + grammar={ "L1": (["L2 L2 L3"], join), "L2": Grammar.NonTerminal(["L3"], join, shared=True), "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), "a": T("a"), "b": T("a"), - } + }, ) grammar_3 = Grammar.from_dict( - { + start_symbol="S", + grammar={ "S": (["mlp", "O"], nn.Sequential), "mlp": (["L", "O", "S O"], nn.Sequential), "L": ( @@ -75,7 +75,7 @@ def join(*s: str) -> str: "linear128": partial(nn.LazyLinear, out_features=64), "relu": nn.ReLU, "elu": nn.ELU, - } + }, ) @@ -154,7 +154,7 @@ def test_string_serialization_and_deserialization_correct( node: Node, ) -> None: # Test parsing - parsed = parse(grammar, string) + parsed = grammar.parse(string) assert parsed == node # Test serialization @@ -166,7 +166,8 @@ def test_string_serialization_and_deserialization_correct( # Test graph and back again graph = to_nxgraph(parsed, include_passthroughs=True) - node_again = to_node_from_graph(graph, grammar) + + node_again = grammar.node_from_graph(graph) assert parsed == node_again @@ -191,7 +192,7 @@ def test_string_serialization_and_deserialization_correct( ) def test_string_deserialization_fail_cases(grammar: Grammar, string: str) -> None: with pytest.raises(ParseError): - parse(grammar, string) + grammar.parse(string) def test_dfs_node_container() -> None: @@ -467,7 +468,7 @@ def test_sample_grammar_and_build_model(grammar: Grammar): t0 = time.perf_counter() samples = 1_000 for _ in range(samples): - sample: Node = sample_grammar("S", grammar=grammar, rng=rng) + sample: Node = grammar.sample("S", rng=rng) model: nn.Module = to_model(sample) model(x) assert sum(p.numel() for p in model.parameters()) > 0 @@ -500,10 +501,9 @@ def test_sample_grammar_and_mutate( time.perf_counter() samples = 1_000 for _ in range(samples): - sample: Node = sample_grammar("S", grammar=grammar, rng=rng) - muts = mutations( + sample: Node = grammar.sample("S", rng=rng) + muts = grammar.mutations( root=sample, - grammar=grammar, which=select(root=sample, how=how), max_mutation_depth=3, ) @@ -514,39 +514,3 @@ def test_sample_grammar_and_mutate( model: nn.Module = to_model(_mut) model(x) assert sum(p.numel() for p in model.parameters()) > 0 - - -grammar_1 = Grammar.from_dict( - { - "s": (["a", "b", "p", "p p"], join), - "p": ["a b", "s"], - "a": T("a"), - "b": T("b"), - } -) - 
-grammar_2 = Grammar.from_dict( - { - "L1": (["L2 L2 L3"], join), - "L2": Grammar.NonTerminal(["L3"], join, shared=True), - "L3": Grammar.NonTerminal(["a", "b"], None, shared=True), - "a": T("a"), - "b": T("a"), - } -) - -grammar_3 = Grammar.from_dict( - { - "S": (["mlp", "O"], nn.Sequential), - "mlp": (["L", "O", "S O"], nn.Sequential), - "L": ( - ["linear64 linear128 relu O linear64 relu O", "linear64 elu linear64"], - nn.Sequential, - ), - "O": (["linear64", "linear64 relu", "linear128 elu"], nn.Sequential), - "linear64": partial(nn.LazyLinear, out_features=64), - "linear128": partial(nn.LazyLinear, out_features=64), - "relu": nn.ReLU, - "elu": nn.ELU, - } -) From 98911f5fae29f94059cd3ac1a8f08d68859ef574 Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 18:22:06 +0100 Subject: [PATCH 48/50] chore: deleting unused cli --- neps/cli/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 neps/cli/__init__.py diff --git a/neps/cli/__init__.py b/neps/cli/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 4c4ab0131c83a8528a5bccff9a471d7bc2b1fbec Mon Sep 17 00:00:00 2001 From: timurcarstensen Date: Thu, 20 Feb 2025 18:25:26 +0100 Subject: [PATCH 49/50] fix: add comment to tests --- tests/test_graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index 927ef7ee2..198d517de 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -473,6 +473,7 @@ def test_sample_grammar_and_build_model(grammar: Grammar): model(x) assert sum(p.numel() for p in model.parameters()) > 0 + # feel free to increase the time limit here, based on running this on a M4 Mac assert time.perf_counter() - t0 < 1 From 4ba9ab3f54e7a0b3cc57026c23a2ce7354264ed1 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 21 Feb 2025 14:25:37 +0100 Subject: [PATCH 50/50] tmp --- neps/optimizers/bayesian_optimization.py | 6 +- neps/optimizers/utils/initial_design.py | 65 +++++++--- neps/space/grammar.py | 158 ++++++++++++++++++++++- 3 files changed, 211 insertions(+), 18 deletions(-) diff --git a/neps/optimizers/bayesian_optimization.py b/neps/optimizers/bayesian_optimization.py index b019136ac..ac4dd1b59 100644 --- a/neps/optimizers/bayesian_optimization.py +++ b/neps/optimizers/bayesian_optimization.py @@ -86,7 +86,11 @@ def __call__( n: int | None = None, ) -> SampledConfig | list[SampledConfig]: assert self.space.fidelity is None, "Fidelity not supported yet." 
- parameters = {**self.space.numerical, **self.space.categoricals} + parameters = { + **self.space.numerical, + **self.space.categoricals, + **self.space.grammars, + } n_to_sample = 1 if n is None else n n_sampled = len(trials) diff --git a/neps/optimizers/utils/initial_design.py b/neps/optimizers/utils/initial_design.py index 615a5a257..6f55fa267 100644 --- a/neps/optimizers/utils/initial_design.py +++ b/neps/optimizers/utils/initial_design.py @@ -5,20 +5,28 @@ import torch +from neps.optimizers.priorband import mutate_config from neps.sampling import Prior, Sampler +from neps.space import Grammar +from neps.space.grammar import RandomSampler, MutationSampler, GrammarSampler if TYPE_CHECKING: - from neps.space import ConfigEncoder - from neps.space.parameters import Parameter + from neps.space import ConfigEncoder, Parameter def make_initial_design( *, - parameters: Mapping[str, Parameter], + parameters: Mapping[str, Parameter | Grammar], encoder: ConfigEncoder, sampler: Literal["sobol", "prior", "uniform"] | Sampler, sample_size: int | Literal["ndim"] | None = "ndim", sample_prior_first: bool = True, + grammar_mutant_selector:( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ) = ("climb", range(1, 4)), + grammar_max_mutation_depth: int = 3, seed: torch.Generator | None = None, ) -> list[dict[str, Any]]: """Generate the initial design of the optimization process. @@ -41,37 +49,64 @@ def make_initial_design( If None, no configurations will be sampled. sample_prior_first: Whether to sample the prior configuration first. + grammar_mutant_selector: Please see [`select()`][neps.space.grammar.select]. + grammar_max_mutation_depth: How deep to enumerate mutants of a prior for the + grammar. seed: The seed to use for the random number generation. """ configs: list[dict[str, Any]] = [] + numerics = {k: p for k, p in parameters.items() if not isinstance(p, Grammar)} + grammars = {k: p for k, p in parameters.items() if isinstance(p, Grammar)} if sample_prior_first: - configs.append( - { - name: p.prior if p.prior is not None else p.center - for name, p in parameters.items() - } - ) - - ndims = len(parameters) + grammar_priors: dict[str, str] = { + k: ( + g.prior + if g.prior is not None + # Ew sorry + else RandomSampler(g).sample(1)[0].to_string() + ) + for k, g in grammars.items() + } + numeric_priors: dict[str, Any] = { + name: p.prior if p.prior is not None else p.center + for name, p in numerics.items() + } + configs.append({**numeric_priors, **grammar_priors}) + + numeric_ndims = len(numerics) + grammar_expansion_count = sum(g._expansion_count for g in grammars.values()) if sample_size == "ndim": - sample_size = ndims + # TODO: Not sure how to handle graphs here properly here to be honest + sample_size = numeric_ndims + grammar_expansion_count elif sample_size is not None and not sample_size > 0: raise ValueError( "The sample size should be a positive integer if passing an int." 
) if sample_size is not None: + # Numeric sampling match sampler: case "sobol": - sampler = Sampler.sobol(ndim=ndims) + numeric_sampler = Sampler.sobol(ndim=numeric_ndims) + grammar_sampler = GrammarSampler.random(grammars) case "uniform": - sampler = Sampler.uniform(ndim=ndims) + numeric_sampler = Sampler.uniform(ndim=numeric_ndims) + grammar_sampler = GrammarSampler.random(grammars) case "prior": - sampler = Prior.from_parameters(parameters) + numeric_sampler = Prior.from_parameters(numerics) + grammar_sampler = GrammarSampler.prior( + grammars, + mutant_selector=grammar_mutant_selector, + max_mutation_depth=grammar_max_mutation_depth + ) case _: pass + # TODO: Replace with something more solid + # Grammar sampling + for k, g in grammars.items(): + encoded_configs = sampler.sample(sample_size * 2, to=encoder.domains, seed=seed) uniq_x = torch.unique(encoded_configs, dim=0) sample_configs = encoder.decode(uniq_x[:sample_size]) diff --git a/neps/space/grammar.py b/neps/space/grammar.py index 0715b7a7a..df90d8b0f 100644 --- a/neps/space/grammar.py +++ b/neps/space/grammar.py @@ -90,7 +90,7 @@ import itertools from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Mapping from dataclasses import dataclass, field from typing import Any, ClassVar, Literal, NamedTuple, TypeAlias from typing_extensions import assert_never @@ -504,12 +504,17 @@ class Grammar: Args: start_symbol: The starting symbol used by optimizers. rules: The possible grammar rules which define the structure of the grammar. + prior: Some prior string producable from the grammar that can be used as a user + prior. """ start_symbol: str rules: dict[str, Terminal | NonTerminal] + prior: str | None = None + _expansion_count: int = field(init=False) _shared: dict[str, NonTerminal] = field(init=False) _leafs: dict[str, Leaf] = field(init=False) + _prior_node: Node | None = field(init=False) class Terminal(NamedTuple): """A symbol which has no children and an associated operation. @@ -570,6 +575,34 @@ def __post_init__(self) -> None: if isinstance(r, Grammar.Terminal) } + # In lue of a good proxy for 'size', which we might need in some scenarios, + # such as in initial design where you can specify sample size is of size + # `"ndim"`, we use this proxy. Totally unscientific. + # Things to consider if you want to change it: + # * Recursive elements (e.g. A -> [A, a, b, c]), where recursion by uniform + # sampling is 1/4. This would need to propogate if for example `b` could also + # recurse on itself.... + # * That we can have multiple children, i.e. `A -> [A a, A b, A c, A A]` + # * Leafs do not expand the size + self._expansion_count = sum( + len(rule.choices) + for rule in self.rules.values() + if isinstance(rule, Grammar.NonTerminal) + ) + + if self.prior is not None: + try: + prior_node = self.parse(self.prior) + except ParseError as e: + raise ValueError( + f"The prior '{self.prior}' given for this grammar could" + " not be parsed properly." + ) from e + else: + prior_node = None + + self._prior_node = prior_node + @classmethod def from_dict( cls, @@ -582,6 +615,8 @@ def from_dict( | Grammar.Terminal | Grammar.NonTerminal, ], + *, + prior: str | None = None, ) -> Grammar: """Create a `Grammar` from a dictionary. @@ -590,6 +625,8 @@ def from_dict( Args: start_symbol: The starting symbol from which to produce strings. grammar: The rules of the grammar. 
+ prior: Some prior string producable from the grammar that can be used as a + user prior. """ rules: dict[str, Grammar.Terminal | Grammar.NonTerminal] = {} for symbol, rule in grammar.items(): @@ -624,7 +661,7 @@ def from_dict( f"\n Got {rule}" ) - return Grammar(start_symbol=start_symbol, rules=rules) + return Grammar(start_symbol=start_symbol, rules=rules, prior=prior) def sample( # noqa: C901, PLR0912 self, @@ -1284,3 +1321,120 @@ def _hierarchy_pos( return pos return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) + + +# TODO: Everything below this point should probably be moved. + + +@dataclass +class RandomSampler: + grammar: Grammar + + def sample( + self, + n: int, + *, + rng: np.random.Generator | _BufferedRandInts | None = None, + ) -> list[Node]: + match rng: + case None: + rng = _BufferedRandInts(rng=np.random.default_rng()) + case np.random.Generator(): + rng = _BufferedRandInts(rng=rng) + case _BufferedRandInts(): + pass + case _: + assert_never(rng) + + return [self.grammar.sample(rng=rng) for _ in range(n)] + + +@dataclass +class MutationSampler: + grammar: Grammar + ref_point: Node + max_mutation_depth: int + mutant_selector: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ) + + def sample(self, n: int, *, rng: np.random.Generator | None = None) -> list[Node]: + if rng is None: + rng = np.random.default_rng() + + nodes_to_mutate_from = self.ref_point.select(how=self.mutant_selector) + all_possible_mutants = self.grammar.mutations( + self.ref_point, + which=nodes_to_mutate_from, + max_mutation_depth=self.max_mutation_depth, + ) + all_possible_mutants = list(all_possible_mutants) + return rng.choice(all_possible_mutants, size=n, replace=False) # type: ignore + + +@dataclass +class GrammarSampler: + samplers: Mapping[str, RandomSampler | MutationSampler] + + def sample( + self, n: int, *, rng: np.random.Generator | None = None + ) -> list[dict[str, Node]]: + """Sample n dictionaries of nodes from the underlying grammar samplers. + + Args: + n: the number of samples to generate. + rng: the random number generator to use. + + Returns: + A list of dictionaries mapping each sampler's key to a sampled Node. + """ + if rng is None: + rng = np.random.default_rng() + + samples: dict[str, list[Node]] = { + k: sampler.sample(n, rng=rng) for k, sampler in self.samplers.items() + } + return [{k: samples[k][i] for k in samples} for i in range(n)] + + @classmethod + def random(cls, grammars: Mapping[str, Grammar]) -> GrammarSampler: + return cls(samplers={k: RandomSampler(g) for k, g in grammars.items()}) + + @classmethod + def prior( + cls, + grammars: Mapping[str, Grammar], + *, + mutant_selector: ( + tuple[Literal["symbol"], str] + | tuple[Literal["depth"], int | range] + | tuple[Literal["climb"], int | range] + ) = ("climb", range(1, 3)), + max_mutation_depth: int = 3, + ) -> GrammarSampler: + """Creates samplers for the grammars, using the prior where possible. + + Grammars without a prior use the `RandomSampler` while those with a prior + have mutations done around the prior. + + Args: + grammars: the grammars to build samplers for. + mutant_selector: Please take a look at [`select()`][neps.space.grammar.select] + max_mutation_depth: Dictates how deep mutations of grammars can go to prevent + overly large configurations due to recursive rules. 
+ """ + samplers: dict[str, RandomSampler | MutationSampler] = {} + for k, g in grammars.items(): + if g._prior_node is not None: + samplers[k] = MutationSampler( + g, + ref_point=g._prior_node, + max_mutation_depth=max_mutation_depth, + mutant_selector=mutant_selector, + ) + else: + samplers[k] = RandomSampler(g) + + return cls(samplers)