Skip to content

Commit ff02519

Browse files
committed
Added normalization option; BAX does not improve with it
1 parent e9f1b1d commit ff02519

File tree

1 file changed: +29 additions, −10 deletions

scripts/KY_topk.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
158158

159159
v = 0
160160
while True:
161-
if v >= len(particles) or v >= len(vals) or v >= len(indices):
161+
if v >= len(particles) or v >= len(vals) or v >= len(indices):
162162
print("Breaking early, not enough significant top-k elements...")
163163
return torch.stack(values), torch.stack(values_x)
164164

@@ -192,26 +192,34 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
192192
# Define the domain resolution and bounds, number of BO iterations,
193193
# number of posterior samples, number of peaks to detect, type of detection
194194
# algorithm, and convergence criteria
195-
N = 100
195+
N = 1000
196196
low_bound = -1
197197
upp_bound = 1
198198

199-
n_iter = 10
199+
n_iter = 20
200200
n_samples = 12
201201
k = 2
202+
normalize = False
202203

203204
# Currently accepted options are: "particle" or "offset"
204-
topk = "particle"
205+
# topk = "particle"
205206
topk = "offset"
206207

207208
# Initialize the system and find the top-k maxima
208209
plot_x = torch.linspace(low_bound, upp_bound, N)
209210
plot_y = f(plot_x)
210211

212+
# Normalized
213+
if normalize == True:
214+
plot_x = (plot_x - low_bound)/(upp_bound - low_bound)
215+
211216
if topk == "offset":
212217
# Choose your desired offset level, i.e., separation between the peaks
213218
topk_algo = naive_topk_eps
214-
buffer = 0.5
219+
if normalize == True:
220+
buffer = 0.1
221+
else:
222+
buffer = 0.1*(upp_bound - low_bound)
215223
print("")
216224
print("--Executing a naive top-k with an offset of", buffer, "to detect the peaks--")
217225
topk_xvals, topk_yvals = naive_topk_eps(plot_x, plot_y, k, buffer)
@@ -251,9 +259,12 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
251259

252260
# Main BO iteration loop
253261
i = 1
254-
tol = 1e-2
262+
tol = 1e-4
255263
err = 1
256-
while i <= n_iter and err > tol:
264+
while (i <= n_iter and err > tol):
265+
if normalize == True:
266+
train_Y = (train_Y - train_Y.mean())/train_Y.std()
267+
257268
# Fitting a GP model
258269
gp = SingleTaskGP(train_X, train_Y)
259270
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
@@ -277,10 +288,12 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
277288

278289
for j in range(n_samples):
279290
sample_xvals, sample_yvals = topk_algo(plot_x, samples[j].squeeze(1), k, buffer)
280-
# print(sample_xvals.detach(), sample_yvals)
281291
new_X = torch.cat((train_X, sample_xvals.detach().unsqueeze(1)))
282292
new_Y = torch.cat((train_Y, sample_yvals.detach().unsqueeze(1)))
283293

294+
if normalize == True:
295+
new_Y = (new_Y - new_Y.mean())/new_Y.std()
296+
284297
new_gp = SingleTaskGP(new_X, new_Y)
285298
new_mll = ExactMarginalLogLikelihood(new_gp.likelihood, new_gp)
286299
fit_gpytorch_model(new_mll)
@@ -300,7 +313,10 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
300313
break
301314

302315
x_next = plot_x[next_indx]
303-
y_next = f(x_next)
316+
if normalize == True:
317+
y_next = f(x_next*(upp_bound - low_bound) + low_bound)
318+
else:
319+
y_next = f(x_next)
304320

305321
train_X = torch.cat((train_X, torch.tensor([x_next]).unsqueeze(1)))
306322
train_Y = torch.cat((train_Y, torch.tensor([y_next]).unsqueeze(1)))
@@ -344,6 +360,9 @@ def particle_topk(xvals, yvals, k, lr, func=interp, n_particles=1000, inner_iter
344360

345361
i += 1
346362

347-
print("Converged in", i, "iterations")
363+
if i-1 != 0:
364+
print("Converged in", i-1, "iterations with error:")
365+
ov_err = np.divide((avg_xvals.detach().numpy() - topk_xvals.detach().numpy()), topk_xvals.detach().numpy())
366+
print(np.round(abs(ov_err)*100, 1))
348367
plt.ioff()
349368
plt.show()

0 commit comments

Comments (0)