-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_cake.py
125 lines (97 loc) · 3.4 KB
/
data_cake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
import torch
# ---
# Data is modifiable here
# ---
# Sector count that data() forwards to make_double_cake_data(). None is only
# a placeholder: the importing program has to assign an integer to it before
# data() is called.
NUMBER_OF_SECTORS = None # value comes from pennylane tester or from main program
# Seed NumPy's global random generator with a fixed value.
# NOTE(review): nothing in the visible part of this file draws from np.random;
# presumably the seed is for code that imports this module - confirm.
np.random.seed(1358)
# ---
# CREATING AND PLOTTING DATA FUNCTIONS (from paper.py)
# ---
def _make_circular_data(num_sectors):
    """Return one ring of points, a single point per sector.

    The full circle is split into ``num_sectors`` equal sectors and one point
    is placed on the bisector of each sector at radius 0.7.

    Returns:
        A tuple ``(x, y, labels)`` of 1-D arrays with ``num_sectors`` entries.
        ``labels`` is a float array alternating -1, +1, -1, ... starting with
        the first sector.
    """
    sector_ids = np.array(range(num_sectors))
    step = 2 * np.pi / num_sectors
    # Angle of the bisector of every sector.
    theta = step * (sector_ids + 0.5)

    radius = 0.7
    # Recover each point's sector number from its angle; its parity (0 or 1)
    # is then mapped onto the class labels -1 / +1.
    parity = (theta // step) % 2
    return radius * np.cos(theta), radius * np.sin(theta), 2 * parity - 1
def make_double_cake_data(num_sectors):
    """Build the two-ring "double cake" dataset.

    An outer ring comes straight from ``_make_circular_data``; an inner ring
    uses the same points scaled by 0.5 with the labels negated, so the two
    points sharing a sector carry opposite labels.

    Returns:
        A tuple ``(xx, yy)``: ``xx`` has one ``(x, y)`` row per point (outer
        ring first, then inner ring) and ``yy`` holds the matching integer
        labels.
    """
    # The helper is deterministic, so one ring serves as the template for both.
    ring_x, ring_y, ring_labels = _make_circular_data(num_sectors)
    inner_scale = 0.5

    xs = np.concatenate((ring_x, inner_scale * ring_x))
    ys = np.concatenate((ring_y, inner_scale * ring_y))
    points = np.vstack((xs, ys)).transpose()

    signs = np.concatenate((ring_labels, -1 * ring_labels))
    return points, signs.astype(int)
def plot_double_cake_data(xx, yy, ax, num_sectors=None):
    """Plot double cake data and corresponding sectors.

    Args:
        xx: Array of points, one ``(x, y)`` row per point (it is transposed
            and unpacked into the x and y coordinates).
        yy: Per-point values used to colour the markers through a two-colour
            (red / blue) colormap.
        ax: Matplotlib axes to draw on.
        num_sectors: When not None, lightly shaded wedges are drawn for that
            many sectors: an outer annulus and an inner disc per sector, with
            the two colours swapped between them.

    Returns:
        The ``ax`` that was passed in.
    """
    palette = ["#FF0000", "#0000FF"]
    x, y = xx.T
    ax.scatter(x, y, c=yy, cmap=mpl.colors.ListedColormap(palette), s=25, marker="s")

    if num_sectors is not None:
        sector_angle = 360 / num_sectors  # Wedge angles are given in degrees
        for i in range(num_sectors):
            start, stop = i * sector_angle, (i + 1) * sector_angle
            # Outer part of the sector: radius 1 with width 0.5.
            ax.add_artist(
                mpl.patches.Wedge(
                    (0, 0),
                    1,
                    start,
                    stop,
                    lw=0,
                    color=palette[i % 2],
                    alpha=0.1,
                    width=0.5,
                )
            )
            # Inner part of the sector: full wedge of radius 0.5, other colour.
            ax.add_artist(
                mpl.patches.Wedge(
                    (0, 0),
                    0.5,
                    start,
                    stop,
                    lw=0,
                    color=palette[(i + 1) % 2],
                    alpha=0.1,
                )
            )

    # NOTE(review): the nesting of these four calls could not be recovered
    # from the flattened source. They are applied once and unconditionally so
    # the frame is the same with and without sectors - confirm this matches
    # the intended layout.
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect("equal")
    ax.axis("off")
    return ax
def plot_decision_boundaries(classifier, ax, x, y, N_gridpoints=14):
    """Evaluate a classifier on a grid over [-1, 1]^2 and draw the result.

    Args:
        classifier: Object with a ``predict`` method. It is called once per
            grid point with an array of shape ``(1, 2)``.
            NOTE(review): assumes ``predict`` returns exactly one value per
            call (scalar or size-1 array) - confirm against the classifier.
        ax: Matplotlib axes to draw on.
        x: Data points, forwarded to ``plot_double_cake_data``.
        y: Data labels, forwarded to ``plot_double_cake_data``.
        N_gridpoints: Number of grid points along each axis.

    Returns:
        A dict with keys ``"_xx"``, ``"_yy"`` and ``"_zz"`` holding the grid
        coordinates and the prediction stored for every grid point.
    """
    axis_ticks = np.linspace(-1, 1, N_gridpoints)
    _xx, _yy = np.meshgrid(axis_ticks, axis_ticks)
    _zz = np.zeros_like(_xx)
    for idx in np.ndindex(*_xx.shape):
        sample = np.array([[_xx[idx], _yy[idx]]])
        # Pull out the single predicted value explicitly: storing a size-1
        # array into a scalar slot relies on an implicit conversion that
        # NumPy has deprecated.
        _zz[idx] = np.asarray(classifier.predict(sample)).item()

    plot_data = {"_xx": _xx, "_yy": _yy, "_zz": _zz}
    ax.contourf(
        _xx,
        _yy,
        _zz,
        cmap=mpl.colors.ListedColormap(["#FF0000", "#0000FF"]),
        alpha=0.2,
        levels=[-1, 0, 1],
    )
    # Overlay the data points; no sector wedges are requested here.
    plot_double_cake_data(x, y, ax)
    return plot_data
# ---
# DATASET FUNCTION TO INITIALIZE AND TO CONVERT DATA INTO TRAINABLE PYTORCH TENSORS
#
# ---
def data():
    """Create the double cake dataset and return it as torch tensors.

    The sector count is read from the module-level ``NUMBER_OF_SECTORS``.

    Returns:
        A pair ``(points, labels)`` of ``torch.float64`` tensors, both
        created with ``requires_grad=True``.
    """
    points, labels = make_double_cake_data(NUMBER_OF_SECTORS)
    # To preview the dataset in a window, uncomment the two lines below:
    # ax = plot_double_cake_data(points, labels, plt.gca(), NUMBER_OF_SECTORS)
    # plt.show()
    features, targets = (
        torch.tensor(arr, dtype=torch.float64, requires_grad=True)
        for arr in (points, labels)
    )
    return features, targets