-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathfrontier_stitching.py
executable file
·66 lines (54 loc) · 2 KB
/
frontier_stitching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import tensorflow as tf
from helpers import binomial
def fast_gradient_signed(x, y, model, eps):
    """Perturb a batch of inputs with the Fast Gradient Sign Method (FGSM).

    Computes ``x + eps * sign(d loss(y, model(x)) / d x)``.

    Args:
        x: Batch of model inputs (tensor or array-like). It is converted to a
            tensor so the gradient tape can watch it. NOTE(review): presumably
            a floating dtype — integer inputs have no gradient; confirm with
            callers.
        y: Labels for ``x`` in whatever format ``model.loss`` expects.
        model: Compiled Keras model. ``model.loss`` may be a loss callable or
            a Keras loss identifier (e.g. a string passed to ``compile``).
        eps: Magnitude of the signed perturbation.

    Returns:
        The perturbed batch, with the same shape as ``x``. The result is not
        clipped to any valid input range.
    """
    # tape.watch only accepts tensors/variables, so numpy batches must be converted.
    x = tf.convert_to_tensor(x)
    # model.loss is whatever was handed to compile(); resolve identifier
    # strings to the corresponding loss callable (callables pass through).
    loss_fn = tf.keras.losses.get(model.loss)
    with tf.GradientTape() as tape:
        tape.watch(x)
        loss = loss_fn(y, model(x))
    # Take the gradient outside the recording context.
    gradient = tape.gradient(loss, x)
    return x + eps * tf.sign(gradient)
def gen_adversaries(model, l, dataset, eps):
    """Collect FGSM "true" and "false" adversaries to use as a watermark key.

    For every input the model classifies correctly, an FGSM perturbation
    (``fast_gradient_signed`` with step ``eps``) is computed:

    * true adversary  -- the perturbed input is misclassified;
    * false adversary -- the perturbed input is still classified correctly.

    Inputs the model already misclassifies are skipped.

    Args:
        model: Keras classifier. Predictions are taken as ``argmax`` over axis 1.
        l: Requested key size. At most ``l // 2`` of each kind is collected,
            so an odd ``l`` yields at most ``l - 1`` key entries.
        dataset: Iterable of ``(x, y)`` batches with sparse integer labels
            ``y`` (the labels are compared against argmax predictions).
        eps: FGSM step size.

    Returns:
        ``(true_advs, false_advs)``: two lists of ``(x_adv, y_true)`` pairs,
        one pair per sample. Each list may be shorter than ``l // 2`` if the
        dataset is exhausted first.
    """
    true_advs = []
    false_advs = []
    max_true_advs = max_false_advs = l // 2
    if max_true_advs <= 0:
        # Nothing can be collected; skip running the attack entirely.
        return true_advs, false_advs

    for x, y in dataset:
        x_advs = fast_gradient_signed(x, y, model, eps)
        y_preds = tf.argmax(model(x), axis=1)
        y_pred_advs = tf.argmax(model(x_advs), axis=1)
        # argmax yields int64; cast the labels to match so the comparisons do
        # not fail on a dtype mismatch (e.g. int32/uint8 labels).
        labels = tf.cast(y, y_preds.dtype)
        clean_correct = tf.equal(y_preds, labels)
        adv_correct = tf.equal(y_pred_advs, labels)

        for x_adv, y_true, is_clean_correct, is_adv_correct in zip(
            x_advs, y, clean_correct, adv_correct
        ):
            # Only inputs the model already gets right are key candidates.
            if not is_clean_correct:
                continue
            if is_adv_correct:
                # Perturbation did not flip the prediction: false adversary.
                if len(false_advs) < max_false_advs:
                    false_advs.append((x_adv, y_true))
            elif len(true_advs) < max_true_advs:
                # Perturbation flipped the prediction: true adversary.
                true_advs.append((x_adv, y_true))

        if len(true_advs) == max_true_advs and len(false_advs) == max_false_advs:
            return true_advs, false_advs
    return true_advs, false_advs
def find_tolerance(key_length, threshold):
    """Return theta, the error bound used to verify a watermark key.

    theta is the smallest integer such that

        sum_{z=0}^{theta} C(key_length, z) / 2**key_length >= threshold,

    i.e. the smallest theta with P[X <= theta] >= threshold for
    X ~ Binomial(key_length, 1/2). Equivalently, P[X < theta] < threshold.

    The sum is evaluated with exact integer/rational arithmetic. A float
    factor ``2 ** -key_length`` would underflow to 0.0, and the running sum
    would overflow float, for key lengths above roughly 1000.

    Args:
        key_length: Number of key inputs (a non-negative integer).
        threshold: Target cumulative probability, at most 1.

    Returns:
        theta as an int in ``[0, key_length]``.

    Raises:
        ValueError: If ``key_length`` is negative or ``threshold > 1``. In the
            latter case no theta exists and the search would never end.
    """
    from fractions import Fraction

    n = int(key_length)
    if n < 0:
        raise ValueError(f"key_length must be non-negative, got {key_length}")
    if threshold > 1:
        raise ValueError(f"threshold must be <= 1, got {threshold}")

    # Test cumulative / 2**n >= threshold exactly as cumulative >= threshold * 2**n.
    target = Fraction(threshold) * (2 ** n)
    cumulative = 0
    coeff = 1  # C(n, theta), updated incrementally below
    theta = 0
    while True:
        cumulative += coeff
        if cumulative >= target:
            return theta
        # C(n, k + 1) = C(n, k) * (n - k) / (k + 1); the division is exact.
        coeff = coeff * (n - theta) // (theta + 1)
        theta += 1
def verify(model, key_set, threshold=0.05):
    """Check whether ``model`` carries the watermark encoded by ``key_set``.

    Counts misclassified key inputs (argmax over axis 1 vs. the stored label).
    It then compares that count against the tolerance
    ``theta = find_tolerance(key_length, threshold)``. Verification succeeds
    when there are strictly fewer than ``theta`` errors.

    Args:
        model: Keras classifier.
        key_set: Iterable of ``(x, y)`` batches with sparse integer labels.
            ``len(x)`` of each batch is summed to obtain the key length.
            NOTE(review): ``gen_adversaries`` returns per-sample pairs;
            presumably callers batch them before calling this — confirm.
        threshold: Cumulative-probability threshold passed to ``find_tolerance``.

    Returns:
        dict with ``"success"`` (bool), ``"false_preds"`` (int, number of
        misclassified key inputs) and ``"max_fals_pred_tolerance"`` (int,
        theta). The misspelled key is kept unchanged for existing callers.
        An empty key set yields 0 errors and theta 0, hence no success.
    """
    false_preds = 0
    key_length = 0
    for x, y in key_set:
        key_length += len(x)
        preds = tf.argmax(model(x), axis=1)
        # argmax yields int64; cast labels so the comparison cannot fail on a
        # dtype mismatch (e.g. int32/uint8 labels).
        labels = tf.cast(y, preds.dtype)
        false_preds += int(tf.math.count_nonzero(tf.not_equal(preds, labels)))
    theta = find_tolerance(key_length, threshold)
    return {
        "success": false_preds < theta,
        "false_preds": false_preds,
        "max_fals_pred_tolerance": theta,
    }