
Commit ed50129

chore: drop FAISS; cap compute (AUC, Contours, SeqLen, Subsets); handle empty tgt_data (#194)
Parent: 324e3b4

6 files changed: +47 −81 lines


mostlyai/qa/_distances.py

Lines changed: 9 additions & 23 deletions

```diff
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import logging
-import platform
 import time
 import numpy as np
 import networkx as nx
 import xxhash
+from sklearn.neighbors import NearestNeighbors
+from joblib import cpu_count
 
 from mostlyai.qa._common import (
     CHARTS_COLORS,
@@ -42,22 +43,9 @@ def calculate_dcrs_nndrs(
     t0 = time.time()
     data = data[data[:, 0].argsort()]  # sort data by first dimension to enforce deterministic results
 
-    if platform.system() == "Linux":
-        # use FAISS on Linux for best performance
-        import faiss  # type: ignore
-
-        index = faiss.IndexFlatL2(data.shape[1])
-        index.add(data)
-        dcrs, _ = index.search(query, 2)
-        dcrs = np.sqrt(dcrs)  # FAISS returns squared distances
-    else:
-        # use sklearn as a fallback on non-Linux systems to avoid segfaults; these occurred when using QA as part of SDK
-        from sklearn.neighbors import NearestNeighbors  # type: ignore
-        from joblib import cpu_count  # type: ignore
-
-        index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="l2", n_jobs=min(16, max(1, cpu_count() - 1)))
-        index.fit(data)
-        dcrs, _ = index.kneighbors(query)
+    index = NearestNeighbors(n_neighbors=2, algorithm="auto", metric="l2", n_jobs=min(16, max(1, cpu_count() - 1)))
+    index.fit(data)
+    dcrs, _ = index.kneighbors(query)
     dcr = dcrs[:, 0]
     nndr = (dcrs[:, 0] + 1e-8) / (dcrs[:, 1] + 1e-8)
     _LOG.info(f"calculated DCRs for {data.shape=} and {query.shape=} in {time.time() - t0:.2f}s")
```
```diff
@@ -85,14 +73,12 @@ def calculate_distances(
     groups = []
     # check all columns together
     groups += [np.arange(ori_embeds.shape[1])]
-    # check subsets of correlated columns together
+    # check 3 correlated subsets of columns
     if ori_embeds.shape[1] > 10:
-        k = max(3, ori_embeds.shape[1] // 10)
-        groups += split_columns_into_correlated_groups(ori_embeds, k=k)
-    # check random subsets of columns
+        groups += split_columns_into_correlated_groups(ori_embeds, k=3)
+    # check 3 random subsets of columns
     if ori_embeds.shape[1] > 10:
-        k = max(3, ori_embeds.shape[1] // 10)
-        groups += split_columns_into_random_groups(ori_embeds, k=3)
+        groups += split_columns_into_random_groups(ori_embeds, k=3)
     dcr_share = 0.0
     nndr_ratio = 1.0
     for columns in groups:
```
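
The subset checks are now fixed at three correlated and three random column groups, rather than scaling the group count with embedding width via `max(3, n_cols // 10)`; this caps the number of distance passes on wide embeddings. A toy sketch of the capped group construction; the two `split_columns_into_*` helpers are repo-internal, so a hypothetical random splitter stands in for both:

```python
# Sketch of the capped group assembly: one full-width group, plus 3 + 3
# column subsets when the embedding has more than 10 columns.
import numpy as np

def split_random(n_cols: int, k: int, rng: np.random.Generator) -> list[np.ndarray]:
    # stand-in for split_columns_into_correlated_groups / _random_groups
    return [np.asarray(g) for g in np.array_split(rng.permutation(n_cols), k)]

ori_embeds = np.random.rand(100, 32)
rng = np.random.default_rng(0)

groups = [np.arange(ori_embeds.shape[1])]  # check all columns together
if ori_embeds.shape[1] > 10:
    groups += split_random(ori_embeds.shape[1], k=3, rng=rng)  # "correlated" subsets
    groups += split_random(ori_embeds.shape[1], k=3, rng=rng)  # random subsets
print(len(groups))  # 7: one full group + 3 + 3 subsets
```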

mostlyai/qa/_sampling.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -300,11 +300,13 @@ def prepare_data_for_embeddings(
 
     # cap to Q95 sequence length of original to avoid excessive samples per group distorting results
     if tgt_context_key is not None:
+        cap_sequence_length = 100
         q95_sequence_length = trn_tgt_data.groupby(key).size().quantile(0.95)
-        syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
-        trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
+        max_sequence_length = min(q95_sequence_length, cap_sequence_length)
+        syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length)
+        trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length)
         hol_tgt_data = (
-            hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length) if hol else None
+            hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=max_sequence_length) if hol else None
         )
 
     # drop key from data as its not relevant for embeddings
```
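
Per-group sequence lengths are now capped at `min(Q95, 100)`, so a handful of extremely long sequences can no longer inflate the embedding sample. A toy illustration with hypothetical `key`/`value` columns:

```python
# Toy illustration of the sequence-length cap: shuffle rows within each
# group, then keep at most min(Q95 of group sizes, 100) rows per group.
import pandas as pd

df = pd.DataFrame({"key": [1] * 5 + [2] * 200, "value": range(205)})

cap_sequence_length = 100
q95_sequence_length = df.groupby("key").size().quantile(0.95)       # ~190 here
max_sequence_length = int(min(q95_sequence_length, cap_sequence_length))  # 100
capped = df.groupby("key").sample(frac=1).groupby("key").head(n=max_sequence_length)
print(capped.groupby("key").size())  # group 1 keeps 5 rows, group 2 is cut to 100
```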

mostlyai/qa/_similarity.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -69,6 +69,10 @@ def calculate_mean_auc(embeds1, embeds2):
     for a ML model to discriminate between two embedding arrays.
     """
 
+    # limit the number of samples to 10000
+    embeds1 = embeds1[:10000]
+    embeds2 = embeds2[:10000]
+
     # create labels for the data
     labels1 = np.zeros(embeds1.shape[0])
     labels2 = np.ones(embeds2.shape[0])
@@ -195,6 +199,11 @@ def plot_store_similarity_contours(
     if trn_embeds.shape[1] < 3:
         return
 
+    # limit the number of samples to 10000
+    syn_embeds = syn_embeds[:10000]
+    trn_embeds = trn_embeds[:10000]
+    hol_embeds = hol_embeds[:10000] if hol_embeds is not None else None
+
     # perform PCA on trn embeddings
     pca_model = PCA(n_components=3)
     pca_model.fit(trn_embeds)
```
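
Both the discriminator AUC and the similarity contour plots now operate on at most 10,000 rows per embedding array, bounding model fitting and PCA cost. As a hedged mini-version of the AUC check (the actual classifier and cross-validation scheme in `calculate_mean_auc` may differ from this sketch): label the two sets 0/1, fit a model, and measure how well it tells them apart; an AUC near 0.5 means the sets are hard to distinguish:

```python
# Hedged mini-version of the capped discriminator AUC check.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def mean_discriminator_auc(embeds1: np.ndarray, embeds2: np.ndarray) -> float:
    # limit the number of samples to 10000 (the cap introduced in this commit)
    embeds1, embeds2 = embeds1[:10_000], embeds2[:10_000]
    X = np.vstack([embeds1, embeds2])
    y = np.concatenate([np.zeros(len(embeds1)), np.ones(len(embeds2))])
    aucs = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=3, scoring="roc_auc")
    return float(aucs.mean())

print(mean_discriminator_auc(np.random.rand(500, 8), np.random.rand(500, 8)))  # ~0.5
```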

mostlyai/qa/reporting.py

Lines changed: 11 additions & 7 deletions

```diff
@@ -181,6 +181,8 @@ def report(
         check_min_sample_size(trn_sample_size, 90, "training")
         if hol_tgt_data is not None:
             check_min_sample_size(hol_sample_size, 10, "holdout")
+        if trn_tgt_data.shape[1] == 0 or syn_tgt_data.shape[1] == 0:
+            raise PrerequisiteNotMetError("Provided data has no columns.")
     except PrerequisiteNotMetError as err:
         _LOG.info(err)
         statistics.mark_early_exit()
```
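
The new guard routes zero-column inputs into the same early-exit path that undersized samples already take, instead of letting them fail deeper in the pipeline. A self-contained sketch of the check; `PrerequisiteNotMetError` is stubbed here since it lives inside the library:

```python
# Sketch of the empty-data guard: a DataFrame with zero columns raises
# immediately rather than failing later during report generation.
import pandas as pd

class PrerequisiteNotMetError(Exception):  # stub for the library's exception
    pass

def check_has_columns(trn_tgt_data: pd.DataFrame, syn_tgt_data: pd.DataFrame) -> None:
    if trn_tgt_data.shape[1] == 0 or syn_tgt_data.shape[1] == 0:
        raise PrerequisiteNotMetError("Provided data has no columns.")

check_has_columns(pd.DataFrame({"a": [1]}), pd.DataFrame({"b": [2]}))  # passes
try:
    check_has_columns(pd.DataFrame(), pd.DataFrame({"b": [2]}))
except PrerequisiteNotMetError as err:
    print(err)  # Provided data has no columns.
```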
```diff
@@ -205,7 +207,6 @@ def report(
     else:
         setup = "1:1"
 
-    _LOG.info("prepare training data for accuracy")
     trn = prepare_data_for_accuracy(
         df_tgt=trn_tgt_data,
         df_ctx=trn_ctx_data,
@@ -214,8 +215,8 @@ def report(
         max_sample_size=max_sample_size_accuracy,
         setup=setup,
     )
+    _LOG.info(f"prepared training data for accuracy: {trn.shape}")
     if hol_tgt_data is not None:
-        _LOG.info("prepare holdout data for accuracy")
         hol = prepare_data_for_accuracy(
             df_tgt=hol_tgt_data,
             df_ctx=hol_ctx_data,
@@ -225,13 +226,13 @@ def report(
             setup=setup,
             ori_dtypes=trn.dtypes.to_dict(),
         )
+        _LOG.info(f"prepared holdout data for accuracy: {hol.shape}")
         ori = pd.concat([trn, hol], axis=0, ignore_index=True)
     else:
         hol = None
         ori = trn
     progress.update(completed=5, total=100)
 
-    _LOG.info("prepare synthetic data for accuracy")
     syn = prepare_data_for_accuracy(
         df_tgt=syn_tgt_data,
         df_ctx=syn_ctx_data,
@@ -241,29 +242,29 @@ def report(
         setup=setup,
         ori_dtypes=trn.dtypes.to_dict(),
     )
+    _LOG.info(f"prepared synthetic data for accuracy: {syn.shape}")
     progress.update(completed=10, total=100)
 
     # do coherence analysis only if there are non-fk columns in the target data
     do_coherence = setup == "1:N" and len(trn_tgt_data.columns) > 1
     if do_coherence:
-        _LOG.info("prepare original data for coherence started")
         ori_coh, ori_coh_bins = prepare_data_for_coherence(
             df_tgt=pd.concat([trn_tgt_data, hol_tgt_data]) if hol_tgt_data is not None else trn_tgt_data,
             tgt_context_key=tgt_context_key,
             max_sample_size=max_sample_size_coherence,
         )
-        _LOG.info("prepare synthetic data for coherence started")
+        _LOG.info(f"prepared original data for coherence: {ori_coh.shape}")
         syn_coh, _ = prepare_data_for_coherence(
             df_tgt=syn_tgt_data,
             tgt_context_key=tgt_context_key,
             bins=ori_coh_bins,
             max_sample_size=max_sample_size_coherence,
         )
-        _LOG.info("store bins used for training data for coherence")
+        _LOG.info(f"prepared synthetic data for coherence: {syn_coh.shape}")
         statistics.store_coherence_bins(bins=ori_coh_bins)
+        _LOG.info("stored bins used for training data for coherence")
     progress.update(completed=15, total=100)
 
-    _LOG.info("calculate embeddings")
     syn_embeds, trn_embeds, hol_embeds = prepare_data_for_embeddings(
         syn_tgt_data=syn_tgt_data,
         trn_tgt_data=trn_tgt_data,
@@ -275,6 +276,9 @@ def report(
         tgt_context_key=tgt_context_key,
         max_sample_size=max_sample_size_embeddings,
     )
+    _LOG.info(
+        f"calculated embeddings: syn={syn_embeds.shape}, trn={trn_embeds.shape}, hol={hol_embeds.shape if hol_embeds is not None else None}"
+    )
     progress.update(completed=20, total=100)
 
     ## 1. ACCURACY ##
```
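
The remaining hunks shift logging from announcing a step beforehand to confirming it afterwards, including the shape of what was produced, so the logs double as a sanity check on sample sizes. A minimal sketch of the pattern; `prepare` is a stand-in for `prepare_data_for_accuracy` and friends:

```python
# Sketch of the log-after-with-shape pattern used throughout report().
import logging
import pandas as pd

logging.basicConfig(level=logging.INFO)
_LOG = logging.getLogger(__name__)

def prepare(df: pd.DataFrame) -> pd.DataFrame:  # stand-in preparation step
    return df.dropna()

trn = prepare(pd.DataFrame({"a": [1, None, 3]}))
_LOG.info(f"prepared training data for accuracy: {trn.shape}")  # logs (2, 1)
```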

pyproject.toml

Lines changed: 0 additions & 1 deletion

```diff
@@ -39,7 +39,6 @@ dependencies = [
     "accelerate>=1.5.0",
     "torch>=2.6.0",
     "xxhash>=3.5.0",
-    "faiss-cpu>=1.7.0",
 ]
 
 [project.urls]
```
