import sys; sys.path.append('../')

import matplotlib.pyplot as plt
import matplotlib as mpl
import torch
import math
import random
import numpy as np
import scipy.io
from scipy.interpolate import interp1d

from botorch.models import SingleTaskGP, FixedNoiseGP
from botorch.fit import fit_gpytorch_model
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.constraints import Interval, GreaterThan
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import CosineKernel, RBFKernel, RQKernel, MaternKernel

# Top-K specific dependencies
from src.alg.algorithms import MyTopK
from src.utils.domain_util import unif_random_sample_domain
from src.utils.misc_util import dict_to_namespace
from argparse import Namespace, ArgumentParser
from src.acq.acquisition import MyBaxAcqFunction
from src.acq.acqoptimize import AcqOptimizer
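
# NOTE: the src.* and argparse imports above are project-specific BAX utilities and
# CLI helpers; none of them are called directly in this standalone script.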


nice_fonts = {
    # Serif fonts with mathtext; set "text.usetex" to True to render text with LaTeX
    "text.usetex": False,
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
    # Axis-label and base font sizes (a thesis layout might use 16 and 14 instead)
    "axes.labelsize": 18,
    "font.size": 18,
    # Make the legend/label fonts a little smaller
    "legend.fontsize": 16,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "figure.figsize": [12, 8],
}

mpl.rcParams.update(nice_fonts)

# device = torch.device("cpu")
dtype = torch.double

# Test function whose top-k local maxima we want to recover
def f(x):
    return torch.sin(8*x)/x.abs()**0.5

# Cubic Hermite interpolation in pure torch so gradients can flow through the
# interpolant (see the usage sketch after h_poly below)
# Adapted from https://gist.github.com/chausies
def interp_func(x, y):
    "Returns a cubic Hermite interpolating function for the points (x, y)"
    if len(y) > 1:
        m = (y[1:] - y[:-1])/(x[1:] - x[:-1])
        m = torch.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]])
    def f(xs):
        if len(y) == 1:  # with a single point, treat as a constant function
            return y[0] + torch.zeros_like(xs)
        I = torch.searchsorted(x[1:], xs)
        dx = (x[I+1] - x[I])
        hh = h_poly((xs - x[I])/dx)
        return hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx

    return f

def interp(x, y, xs):
    return interp_func(x, y)(xs)

def h_poly_helper(tt):
    A = torch.tensor([
        [1, 0, -3, 2],
        [0, 1, -2, 1],
        [0, 0, 3, -2],
        [0, 0, -1, 1]
        ], dtype=tt[-1].dtype)
    return [
        sum(A[i, j]*tt[j] for j in range(4))
        for i in range(4)]

def h_poly(t):
    tt = [None for _ in range(4)]
    tt[0] = 1
    for i in range(1, 4):
        tt[i] = tt[i-1]*t
    return h_poly_helper(tt)
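
# Usage sketch (illustrative values): the interpolant is differentiable with
# respect to the query points, which is what particle_topk relies on below.
#   x_grid = torch.linspace(0.0, 1.0, 100)
#   y_grid = torch.sin(8*x_grid)
#   xs = torch.rand(50, requires_grad=True)
#   grad = torch.autograd.grad(interp(x_grid, y_grid, xs).sum(), xs)[0]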

# Different top-k flavors
def naive_topk(xvals, yvals, k):
    """
    Return the k largest y-values and their x locations (no separation constraint).
    """
    glob_indices = torch.argsort(yvals, descending=True)[:k]

    return xvals[glob_indices], yvals[glob_indices]
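
# e.g. naive_topk(plot_x, plot_y, 3) simply returns the three largest samples,
# which may all sit on the same broad peak; the eps-separated variant below avoids that.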

def naive_topk_eps(xvals, yvals, k, eps):
    """
    Return the top-k y-values whose x locations are at least eps apart from each other.
    """
    glob_indices = torch.argsort(yvals, descending=True).tolist()

    topk_indices = [glob_indices[0]]

    i = 1
    while i < len(glob_indices) and len(topk_indices) < k:
        dists = np.zeros(len(topk_indices))
        for ii in range(len(topk_indices)):
            dists[ii] = abs(xvals[glob_indices[i]] - xvals[topk_indices[ii]])

        if (dists >= eps).all():
            topk_indices.append(glob_indices[i])
        i += 1

    return xvals[topk_indices], yvals[topk_indices]
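
# e.g. naive_topk_eps(plot_x, plot_y, 3, 0.1) walks down the values in descending
# order and accepts a candidate only if it is at least 0.1 away (in x) from every
# peak accepted so far.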

def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iters=1000, eps=1e-1):
    """
    Return the top-k local maxima using a particle gradient-ascent approach.
    Randomly initialize n_particles in the domain, push them uphill along the
    gradient of a differentiable interpolant of (xvals, yvals) until they settle
    on peaks, then greedily keep k peaks that are at least eps apart in x.
    """
    dom_min, dom_max = xvals.min(), xvals.max()

    # sample `n_particles` uniformly in [dom_min, dom_max]
    particles = (dom_max - dom_min) * torch.rand(n_particles) + dom_min
    particles = particles.requires_grad_(True)

    for i in range(inner_iters):
        old_particles = particles.detach()
        y = func(xvals, yvals, particles)
        grad = torch.autograd.grad(y.sum(), particles)[0]
        # gradient-ascent step, clamped to the domain
        particles = torch.clamp(particles + lr*grad, min=dom_min, max=dom_max)
        particles = particles.detach().requires_grad_(True)
        if (particles - old_particles).abs().max() < eps/10:
            # particles have stopped moving appreciably
            break

    y = func(xvals, yvals, particles)

    # pick top-k; x_ holds the previous pick, initialized so the first pick is unconstrained
    x_ = torch.tensor([float("inf")])

    vals, indices = y.sort(descending=True)
    values = []
    values_x = []

    for i in range(k):
        if len(particles) == 0:
            print("Breaking early, not enough significant top-k elements...")
            break

        v = 0
        while True:
            if v >= len(vals):
                print("Breaking early, not enough significant top-k elements...")
                if len(values) == 0:
                    return torch.tensor([]), torch.tensor([])
                return torch.stack(values_x), torch.stack(values)

            candidate = vals[v]
            candidate_x = particles[indices[v]]
            v = v + 1
            if (candidate_x - x_).abs() >= eps: break

        values.append(candidate)
        values_x.append(candidate_x)

        # pop the chosen particle from the pool
        keep = torch.arange(len(particles)) != indices[v-1]
        particles = particles[keep]

        # remove all remaining particles within eps of the previous pick
        idxs_to_keep = []
        for j in range(len(particles)):
            if (particles[j] - x_).abs() >= eps:
                idxs_to_keep.append(j)

        particles = particles[idxs_to_keep]
        y = func(xvals, yvals, particles)
        vals, indices = y.sort(descending=True)

        x_ = candidate_x

    return torch.stack(values_x), torch.stack(values)
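
# Usage sketch (illustrative values): recover three well-separated peaks of a curve.
#   xs = torch.linspace(0.0, 1.0, 200)
#   ys = torch.sin(20*xs)
#   peak_x, peak_y = particle_topk(xs, ys, k=3, lr=0.05)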

# Define the domain resolution and bounds, number of BO iterations,
# number of posterior samples, number of peaks to detect, type of detection
# algorithm, and convergence criteria
N = 100
low_bound = -1
upp_bound = 1

n_iter = 20
n_samples = 12
k = 3

# The normalization in x is not necessary, but it makes the convergence criteria
# consistent with the BO implementation
normalize_x = True

# Currently accepted options are: "particle" or "offset"
# topk = "particle"
topk = "offset"

# Initialize the system and find the true top-k maxima
plot_x = torch.linspace(low_bound, upp_bound, N, dtype=dtype)
plot_y = f(plot_x)

# Normalize x to [0, 1]
if normalize_x:
    plot_x = (plot_x - low_bound)/(upp_bound - low_bound)

if topk == "offset":
    # Choose the desired offset level, i.e., the minimum separation between peaks
    topk_algo = naive_topk_eps
    if normalize_x:
        buffer = 0.1
    else:
        buffer = 0.1*(upp_bound - low_bound)
    print("")
    print("--Executing a naive top-k with an offset of", buffer, "to detect the peaks--")
    topk_xvals, topk_yvals = topk_algo(plot_x, plot_y, k, buffer)
elif topk == "particle":
    topk_algo = particle_topk
    print("")
    print("--Executing a particle-search top-k to detect the peaks--")
    # For convergence we need a step size <= 1/L, where L is the largest slope
    # magnitude (a finite-difference estimate of the Lipschitz constant)
    L = (plot_y[1:] - plot_y[:-1])/(plot_x[1:] - plot_x[:-1])
    buffer = 0.5/L.abs().max()
    topk_xvals, topk_yvals = topk_algo(plot_x, plot_y, k, buffer)

print("")
print("Actual top-k values (x-array, y-array):")
print(topk_xvals.detach().numpy(), topk_yvals.detach().numpy())
print("")

# Construct the initial training set from five evenly spaced grid points
indx_list = [0, N//4, N//2, 3*N//4, N-1]
train_X = plot_x[indx_list].unsqueeze(-1).to(dtype)
train_Y = plot_y[indx_list].unsqueeze(-1).to(dtype)
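
# BoTorch's GP models expect train_X of shape (n, d) and train_Y of shape (n, m);
# here d = m = 1, so both are column vectors.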

# Display an interactive plot to monitor system behavior
plt.ion()
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]})
ax1.plot(plot_x, plot_y, 'k-', label='f(x)')
ax1.plot(topk_xvals.detach(), topk_yvals.detach(), 'r1', markersize=20, label='Actual Top-k Max')
ax1.plot(train_X, train_Y, linestyle=' ', color='k', marker='o', mfc='None', markersize=8, label='Observations')
ax1.set_ylim(-3, 4.5)
ax1.legend(loc='upper left', frameon=False)

ax2.set(xlabel='x', ylabel='EIG')
plt.tight_layout()
plt.pause(0.4)

# Main BO iteration loop
i = 1
tol = 1e-3
err = 1
noise_Y = 0.1
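# noise_Y is used below as the per-point observation-noise *variance* (train_Yvar)
# handed to FixedNoiseGP, not a standard deviation.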
while i <= n_iter and err > tol:
    # Fit a GP model; keep the y-values unnormalized
    train_Yvar = torch.full_like(train_Y, noise_Y)
    gp = FixedNoiseGP(train_X, train_Y, train_Yvar)
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(mll)

    # Posterior over the plotting grid (BoTorch expects n x d inputs)
    posterior = gp.posterior(plot_x.unsqueeze(-1))

    # Differential entropy of the Gaussian posterior at each grid point
    std_arr = posterior.variance.sqrt()
    entropy = torch.log(std_arr) + torch.log(torch.tensor(2*torch.pi).sqrt()) + 0.5
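    # For a Gaussian, differential entropy is H = 0.5*log(2*pi*e*sigma^2)
    # = log(sigma) + 0.5*log(2*pi) + 0.5, which matches the expression above.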

    samples = posterior.sample(sample_shape=torch.Size([n_samples]))

    # For each posterior sample, run top-k and get the x and y locations of its peaks.
    # Treat those as extra observations alongside the training data and fit another GP
    # (SingleTaskGP). Compute the entropy of that GP's posterior over the grid, average
    # over samples to estimate the expected post-algorithm entropy, and subtract it from
    # the current entropy to obtain the acquisition function -> its maximum is the next query.
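    #   EIG(x) ~ H[y(x) | D] - (1/n_samples) * sum_j H[y(x) | D + A_j],
    # where A_j is the set of top-k points extracted from the j-th posterior sample.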
    new_entropy = torch.zeros(N, 1, dtype=dtype)

    for j in range(n_samples):
        if topk == "particle":
            # Re-estimate the step size from the sampled curve
            L = (samples[j].squeeze(1)[1:] - samples[j].squeeze(1)[:-1])/(plot_x[1:] - plot_x[:-1])
            buffer = 0.5/L.abs().max()

        sample_xvals, sample_yvals = topk_algo(plot_x, samples[j].squeeze(1), k, buffer)
        new_X = torch.cat((train_X, sample_xvals.detach().unsqueeze(1)))
        new_Y = torch.cat((train_Y, sample_yvals.detach().unsqueeze(1)))

        new_gp = SingleTaskGP(new_X, new_Y)
        new_mll = ExactMarginalLogLikelihood(new_gp.likelihood, new_gp)
        fit_gpytorch_model(new_mll)

        new_posterior = new_gp.posterior(plot_x.unsqueeze(-1))

        new_std = new_posterior.variance.sqrt()

        new_entropy += torch.log(new_std) + torch.log(torch.tensor(2*torch.pi).sqrt()) + 0.5

    EIG = entropy - new_entropy/n_samples
    EIG_order = torch.argsort(EIG.squeeze(1).detach(), descending=True)
    # Ensure that the chosen index has not already been sampled
    for jj in EIG_order:
        if jj not in indx_list:
            next_indx = jj
            break

    x_next = plot_x[next_indx]
    if normalize_x:
        y_next = f(x_next*(upp_bound - low_bound) + low_bound)
    else:
        y_next = f(x_next)

    train_X = torch.cat((train_X, x_next.reshape(1, 1)))
    train_Y = torch.cat((train_Y, y_next.reshape(1, 1)))
    indx_list.append(next_indx.item())

    if i == 1:
        mylabel = 'Posterior'
        mylabel2 = 'BAX Top-k'
        mylabel3 = 'BO Moves'
    else:
        mylabel = None
        mylabel2 = None
        mylabel3 = None

    avg_posterior = posterior.mean
    if i == 1:
        avg_xvals, avg_yvals = topk_algo(plot_x, avg_posterior.squeeze(1).detach(), k, buffer)
    else:
        old_xvals, old_yvals = avg_xvals, avg_yvals
        avg_xvals, avg_yvals = topk_algo(plot_x, avg_posterior.squeeze(1).detach(), k, buffer)
        # Convergence criterion: change in the estimated peak locations between iterations
        # err = (avg_xvals - old_xvals).abs().sum()
        err = torch.norm(avg_xvals - old_xvals)
        print("Iteration, next x, error, topk max")
        print(i, x_next.detach().numpy(), err.detach().numpy(), avg_xvals.detach().numpy())

    # Plot to see the BO moves
    ax1.plot(x_next, y_next, 'bs', markersize=8, label=mylabel3)
    ax1.legend(loc='upper left', frameon=False)
    ax1.plot(plot_x, avg_posterior.detach(), linestyle='--', color='m', linewidth=1.0, label=mylabel)
    ax1.plot(avg_xvals.detach(), avg_yvals.detach(), 'y^', markersize=10, label=mylabel2)
    plt.draw()

    ax2.plot(plot_x, EIG.detach(), 'g', alpha=0.5)
    ax2.set(xlabel='x', ylabel='EIG')
    plt.draw()
    plt.tight_layout()
    plt.pause(0.25)

    if err > tol:
        plt.cla()

    i += 1

if err <= tol:
    print("Converged in", i-1, "iterations with relative error (%) in the peak locations:")
else:
    print("Stopped after", i-1, "iterations (tolerance not reached); relative error (%) in the peak locations:")
ov_err = np.divide((avg_xvals.detach().numpy() - topk_xvals.detach().numpy()), topk_xvals.detach().numpy())
print(np.round(abs(ov_err)*100, 1))
plt.ioff()
plt.show()