Commit e24c658

chore: dedicated data pull for embeddings to fetch all events for sequential data (#192)
1 parent a6b1b23 commit e24c658

6 files changed: +107 -45 lines changed

mostlyai/qa/_common.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 
 
 ACCURACY_MAX_COLUMNS = 300  # should be an even number and greater than 100
+EMBEDDINGS_MAX_COLUMNS = 300
 
 MAX_UNIVARIATE_PLOTS = 300
 MAX_BIVARIATE_TGT_PLOTS = 300
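
The new EMBEDDINGS_MAX_COLUMNS constant gives the embeddings step its own column cap, separate from ACCURACY_MAX_COLUMNS. A minimal sketch of how such a cap is applied, using a hypothetical wide frame (prepare_data_for_embeddings below performs the equivalent slice):

    import pandas as pd

    EMBEDDINGS_MAX_COLUMNS = 300

    # hypothetical frame with more columns than the cap allows
    df = pd.DataFrame({f"col_{i}": [1, 2, 3] for i in range(500)})

    # keep only the first EMBEDDINGS_MAX_COLUMNS columns
    df = df[list(df.columns)[:EMBEDDINGS_MAX_COLUMNS]]
    print(df.shape)  # (3, 300)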

mostlyai/qa/_embeddings.py

Lines changed: 9 additions & 28 deletions
@@ -12,22 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import numpy as np
 import pandas as pd
 from sklearn.decomposition import PCA
 from sklearn.preprocessing import QuantileTransformer, normalize
+from pandas.core.dtypes.common import is_numeric_dtype, is_datetime64_dtype
 
 from mostlyai.qa._common import (
-    COUNT_COLUMN,
     EMPTY_BIN,
     NA_BIN,
-    NXT_COLUMN_PREFIX,
     RARE_BIN,
-    TGT_COLUMN_PREFIX,
 )
 from mostlyai.qa.assets import load_embedder
 
 
+_LOG = logging.getLogger(__name__)
+
+
 def encode_numerics(
     syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
 ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame | None]:

@@ -128,36 +130,15 @@ def encode_strings(
 
 
 def encode_data(
-    syn: pd.DataFrame, trn: pd.DataFrame, hol: pd.DataFrame | None = None
+    syn_data: pd.DataFrame, trn_data: pd.DataFrame, hol_data: pd.DataFrame | None = None
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
     """
     Encode all columns corresponding to their data type.
     """
-    tgt_cols = [c for c in trn.columns if c.startswith(TGT_COLUMN_PREFIX)]
-    nxt_cols = [c for c in trn.columns if c.startswith(NXT_COLUMN_PREFIX)]
-    cnt_col = f"{TGT_COLUMN_PREFIX}{COUNT_COLUMN}"
-    cnt_cols = [cnt_col] if cnt_col in trn.columns else []
+    _LOG.info("encode datasets for embeddings")
     # split into numeric and string columns
-    num_dat_cols = [
-        col for col in tgt_cols if pd.api.types.is_numeric_dtype(trn[col]) or pd.api.types.is_datetime64_dtype(trn[col])
-    ]
-    string_cols = [col for col in tgt_cols if col not in num_dat_cols]
-    # keep TGT data
-    syn_data = syn[tgt_cols]
-    trn_data = trn[tgt_cols]
-    hol_data = hol[tgt_cols] if hol is not None else None
-    # append NXT data with TGT data to increase data coverage
-    if len(nxt_cols) > 0:
-        syn_nxt = syn[cnt_cols + nxt_cols]
-        syn_nxt.columns = syn_nxt.columns.str.replace(NXT_COLUMN_PREFIX, TGT_COLUMN_PREFIX)
-        syn_data = pd.concat([syn_data, syn_nxt], axis=0)
-        trn_nxt = trn[cnt_cols + nxt_cols]
-        trn_nxt.columns = trn_nxt.columns.str.replace(NXT_COLUMN_PREFIX, TGT_COLUMN_PREFIX)
-        trn_data = pd.concat([trn_data, trn_nxt], axis=0)
-        if hol is not None:
-            hol_nxt = hol[cnt_cols + nxt_cols]
-            hol_nxt.columns = hol_nxt.columns.str.replace(NXT_COLUMN_PREFIX, TGT_COLUMN_PREFIX)
-            hol_data = pd.concat([hol_data, hol_nxt], axis=0)
+    num_dat_cols = [col for col in trn_data if is_numeric_dtype(trn_data[col]) or is_datetime64_dtype(trn_data[col])]
+    string_cols = [col for col in trn_data if col not in num_dat_cols]
     # encode numeric columns
     syn_num, trn_num, hol_num = encode_numerics(
         syn_data[num_dat_cols], trn_data[num_dat_cols], hol_data[num_dat_cols] if hol_data is not None else None
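
With the TGT/NXT prefix handling moved out of this module, encode_data now only splits its already-prepared inputs by dtype. A small illustration of that split on a toy frame (frame and values are illustrative only):

    import pandas as pd
    from pandas.core.dtypes.common import is_datetime64_dtype, is_numeric_dtype

    # toy frame mixing numeric, datetime, and string columns
    df = pd.DataFrame(
        {
            "amount": [1.0, 2.5, 3.0],
            "ts": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
            "category": ["a", "b", "a"],
        }
    )

    # the same dtype-based split that encode_data performs
    num_dat_cols = [col for col in df if is_numeric_dtype(df[col]) or is_datetime64_dtype(df[col])]
    string_cols = [col for col in df if col not in num_dat_cols]
    print(num_dat_cols)  # ['amount', 'ts']
    print(string_cols)   # ['category']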

mostlyai/qa/_sampling.py

Lines changed: 82 additions & 1 deletion
@@ -34,9 +34,11 @@
 import pandas as pd
 import pyarrow as pa
 
+from mostlyai.qa._embeddings import encode_data
 from mostlyai.qa._accuracy import bin_data
 from mostlyai.qa._common import (
     CTX_COLUMN_PREFIX,
+    EMBEDDINGS_MAX_COLUMNS,
     TGT_COLUMN_PREFIX,
     NXT_COLUMN_PREFIX,
     COUNT_COLUMN,

@@ -131,7 +133,7 @@ def prepare_data_for_accuracy(
     # harmonize dtypes
     df = df.apply(harmonize_dtype)
 
-    # coerce dtypes to trn_dtypes
+    # coerce dtypes to ori_dtypes
     for trn_col, trn_dtype in (ori_dtypes or {}).items():
         if is_numeric_dtype(trn_dtype):
             df[trn_col] = pd.to_numeric(df[trn_col], errors="coerce")

@@ -262,3 +264,82 @@ def is_timestamp_dtype(x: pd.Series) -> bool:
 def is_text_heuristic(x: pd.Series) -> bool:
     # if more than 5% of rows contain unique values -> consider as TEXT
     return x.dtype == "object" and x.value_counts().eq(1).reindex(x).mean() > 0.05
+
+
+def prepare_data_for_embeddings(
+    *,
+    syn_tgt_data: pd.DataFrame,
+    trn_tgt_data: pd.DataFrame,
+    hol_tgt_data: pd.DataFrame | None = None,
+    syn_ctx_data: pd.DataFrame | None = None,
+    trn_ctx_data: pd.DataFrame | None = None,
+    hol_ctx_data: pd.DataFrame | None = None,
+    ctx_primary_key: str | None = None,
+    tgt_context_key: str | None = None,
+    max_sample_size: int | None = None,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
+    # helper variables
+    key = tgt_context_key or None
+    hol = hol_tgt_data is not None
+
+    # filter target to context keys
+    if trn_ctx_data is not None:
+        rename_key = {ctx_primary_key: key}
+        syn_ctx_data = syn_ctx_data[[ctx_primary_key]].rename(columns=rename_key)
+        trn_ctx_data = trn_ctx_data[[ctx_primary_key]].rename(columns=rename_key)
+        hol_ctx_data = hol_ctx_data[[ctx_primary_key]].rename(columns=rename_key) if hol else None
+        syn_tgt_data = syn_tgt_data.merge(syn_ctx_data, on=key, how="inner")
+        trn_tgt_data = trn_tgt_data.merge(trn_ctx_data, on=key, how="inner")
+        hol_tgt_data = hol_tgt_data.merge(hol_ctx_data, on=key, how="inner") if hol else None
+
+    # enrich with count column
+    if tgt_context_key is not None:
+        syn_tgt_data.insert(0, COUNT_COLUMN, syn_tgt_data.groupby(key)[key].transform("size"))
+        trn_tgt_data.insert(0, COUNT_COLUMN, trn_tgt_data.groupby(key)[key].transform("size"))
+        hol_tgt_data.insert(0, COUNT_COLUMN, hol_tgt_data.groupby(key)[key].transform("size")) if hol else None
+
+    # cap to Q95 sequence length of original to avoid excessive samples per group distorting results
+    if tgt_context_key is not None:
+        q95_sequence_length = trn_tgt_data.groupby(key).size().quantile(0.95)
+        syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
+        trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
+        hol_tgt_data = (
+            hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length) if hol else None
+        )
+
+    # drop key from data as it's not relevant for embeddings
+    if tgt_context_key is not None:
+        syn_tgt_data = syn_tgt_data.drop(columns=[key])
+        trn_tgt_data = trn_tgt_data.drop(columns=[key])
+        hol_tgt_data = hol_tgt_data.drop(columns=[key]) if hol else None
+
+    # draw equally sized samples for fair 3-way comparison
+    max_sample_size = min(
+        max_sample_size or float("inf"),
+        len(syn_tgt_data),
+        len(trn_tgt_data),
+        len(hol_tgt_data) if hol_tgt_data is not None else float("inf"),
+    )
+    syn_tgt_data = syn_tgt_data.sample(n=max_sample_size)
+    trn_tgt_data = trn_tgt_data.sample(n=max_sample_size)
+    hol_tgt_data = hol_tgt_data.sample(n=max_sample_size) if hol else None
+
+    # limit to same columns
+    trn_cols = list(trn_tgt_data.columns)[:EMBEDDINGS_MAX_COLUMNS]
+    syn_tgt_data = syn_tgt_data[trn_cols]
+    trn_tgt_data = trn_tgt_data[trn_cols]
+    hol_tgt_data = hol_tgt_data[trn_cols] if hol else None
+
+    # harmonize dtypes
+    syn_tgt_data = syn_tgt_data.apply(harmonize_dtype)
+    trn_tgt_data = trn_tgt_data.apply(harmonize_dtype)
+    hol_tgt_data = hol_tgt_data.apply(harmonize_dtype) if hol else None
+
+    # encode data
+    syn_embeds, trn_embeds, hol_embeds = encode_data(
+        syn_data=syn_tgt_data,
+        trn_data=trn_tgt_data,
+        hol_data=hol_tgt_data,
+    )
+
+    return syn_embeds, trn_embeds, hol_embeds
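
For reference, a usage sketch of the new prepare_data_for_embeddings on toy sequential data (the frames and key names are hypothetical; running it requires the mostlyai-qa package with its embedder assets, since encode_data loads an embedding model):

    import pandas as pd
    from mostlyai.qa._sampling import prepare_data_for_embeddings

    # hypothetical sequential setup: one context row per subject,
    # several target events per subject
    trn_ctx = pd.DataFrame({"id": [1, 2, 3]})
    syn_ctx = pd.DataFrame({"id": [4, 5, 6]})
    trn_tgt = pd.DataFrame({"id": [1, 1, 2, 2, 2, 3], "amount": [10, 20, 5, 7, 9, 42]})
    syn_tgt = pd.DataFrame({"id": [4, 4, 5, 6, 6, 6], "amount": [11, 19, 6, 40, 41, 44]})

    syn_embeds, trn_embeds, hol_embeds = prepare_data_for_embeddings(
        syn_tgt_data=syn_tgt,
        trn_tgt_data=trn_tgt,
        syn_ctx_data=syn_ctx,
        trn_ctx_data=trn_ctx,
        ctx_primary_key="id",
        tgt_context_key="id",
        max_sample_size=1_000,
    )
    # hol_embeds is None because no holdout data was passed

Note that, per the commit title, all events of a sequence are now pulled in for embeddings; groups are only shuffled and truncated at the training data's 95th-percentile sequence length, so a few very long sequences cannot dominate the embedding sample.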

mostlyai/qa/reporting.py

Lines changed: 12 additions & 12 deletions
@@ -20,7 +20,7 @@
 import pandas as pd
 from pandas.core.dtypes.common import is_numeric_dtype, is_datetime64_dtype
 
-from mostlyai.qa import _distances, _similarity, _html_report, _embeddings
+from mostlyai.qa import _distances, _similarity, _html_report
 from mostlyai.qa._accuracy import (
     bin_data,
     binning_data,

@@ -50,6 +50,7 @@
 from mostlyai.qa._sampling import (
     prepare_data_for_accuracy,
     prepare_data_for_coherence,
+    prepare_data_for_embeddings,
 )
 from mostlyai.qa._common import (
     determine_data_size,

@@ -263,17 +264,16 @@ def report(
         progress.update(completed=15, total=100)
 
         _LOG.info("calculate embeddings")
-        # ensure that embeddings are all equal size for a fair 3-way comparison
-        max_sample_size_embeddings_final = min(
-            max_sample_size_embeddings or float("inf"),
-            syn_sample_size,
-            trn_sample_size,
-            hol_sample_size or float("inf"),
-        )
-        syn_embeds, trn_embeds, hol_embeds = _embeddings.encode_data(
-            syn=syn.head(max_sample_size_embeddings_final),
-            trn=trn.head(max_sample_size_embeddings_final),
-            hol=hol.head(max_sample_size_embeddings_final) if hol is not None else None,
+        syn_embeds, trn_embeds, hol_embeds = prepare_data_for_embeddings(
+            syn_tgt_data=syn_tgt_data,
+            trn_tgt_data=trn_tgt_data,
+            hol_tgt_data=hol_tgt_data,
+            syn_ctx_data=syn_ctx_data,
+            trn_ctx_data=trn_ctx_data,
+            hol_ctx_data=hol_ctx_data,
+            ctx_primary_key=ctx_primary_key,
+            tgt_context_key=tgt_context_key,
+            max_sample_size=max_sample_size_embeddings,
         )
         progress.update(completed=20, total=100)
 
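
The equal-size logic that previously lived inline in report() has moved into prepare_data_for_embeddings. A stand-alone sketch of that sizing rule (syn_df, trn_df, hol_df are hypothetical stand-ins):

    import pandas as pd

    syn_df = pd.DataFrame({"x": range(1_000)})
    trn_df = pd.DataFrame({"x": range(800)})
    hol_df = None  # no holdout in this sketch
    max_sample_size = 500

    # draw equally sized samples for a fair 3-way comparison
    n = min(
        max_sample_size or float("inf"),
        len(syn_df),
        len(trn_df),
        len(hol_df) if hol_df is not None else float("inf"),
    )
    syn_df = syn_df.sample(n=n)
    trn_df = trn_df.sample(n=n)
    print(len(syn_df), len(trn_df))  # 500 500

One behavioral difference: the old code took the first rows via head(), while the new code draws a random sample, after capping sequence lengths at the Q95 threshold.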

tests/end_to_end/test_report.py

Lines changed: 0 additions & 1 deletion
@@ -304,4 +304,3 @@ def generate_dates(start_date, end_date, num_samples):
         "Expected a warning about dtype mismatch for column 'dt'"
     )
     assert statistics.accuracy.overall > 0.6
-    assert 0.2 < statistics.similarity.discriminator_auc_training_synthetic < 0.8

tests/unit/test_html_report.py

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ def test_generate_store_report(tmp_path, cols, workspace):
     acc_seqs_per_cat = pd.DataFrame({"column": acc_uni["column"], "accuracy": 0.5, "accuracy_max": 0.5})
     corr_trn = _accuracy.calculate_correlations(acc_trn)
     syn_embeds, trn_embeds, hol_embeds = _embeddings.encode_data(
-        syn=syn,
-        trn=trn,
-        hol=hol,
+        syn_data=syn,
+        trn_data=trn,
+        hol_data=hol,
     )
     sim_cosine_trn_hol, sim_cosine_trn_syn = _similarity.calculate_cosine_similarities(
         syn_embeds=syn_embeds,
