|
34 | 34 | import pandas as pd
|
35 | 35 | import pyarrow as pa
|
36 | 36 |
|
| 37 | +from mostlyai.qa._embeddings import encode_data |
37 | 38 | from mostlyai.qa._accuracy import bin_data
|
38 | 39 | from mostlyai.qa._common import (
|
39 | 40 | CTX_COLUMN_PREFIX,
|
| 41 | + EMBEDDINGS_MAX_COLUMNS, |
40 | 42 | TGT_COLUMN_PREFIX,
|
41 | 43 | NXT_COLUMN_PREFIX,
|
42 | 44 | COUNT_COLUMN,
|
@@ -131,7 +133,7 @@ def prepare_data_for_accuracy(
|
131 | 133 | # harmonize dtypes
|
132 | 134 | df = df.apply(harmonize_dtype)
|
133 | 135 |
|
134 |
| - # coerce dtypes to trn_dtypes |
| 136 | + # coerce dtypes to ori_dtypes |
135 | 137 | for trn_col, trn_dtype in (ori_dtypes or {}).items():
|
136 | 138 | if is_numeric_dtype(trn_dtype):
|
137 | 139 | df[trn_col] = pd.to_numeric(df[trn_col], errors="coerce")
|
@@ -262,3 +264,82 @@ def is_timestamp_dtype(x: pd.Series) -> bool:
|
def is_text_heuristic(x: pd.Series) -> bool:
    """Heuristically decide whether a column should be treated as TEXT.

    A column counts as TEXT when it is object-dtyped and more than 5% of its
    rows hold a value that occurs exactly once in the column.
    """
    if x.dtype != "object":
        return False
    # share of rows whose value is unique within the column
    unique_row_share = x.value_counts().eq(1).reindex(x).mean()
    return unique_row_share > 0.05
|
| 267 | + |
| 268 | + |
def prepare_data_for_embeddings(
    *,
    syn_tgt_data: pd.DataFrame,
    trn_tgt_data: pd.DataFrame,
    hol_tgt_data: pd.DataFrame | None = None,
    syn_ctx_data: pd.DataFrame | None = None,
    trn_ctx_data: pd.DataFrame | None = None,
    hol_ctx_data: pd.DataFrame | None = None,
    ctx_primary_key: str | None = None,
    tgt_context_key: str | None = None,
    max_sample_size: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
    """Align synthetic / training / holdout target data and encode it into embeddings.

    Makes the three data sets directly comparable before encoding: targets are
    filtered to rows with a matching context key, sequential data is enriched
    with a per-group size column and capped at the training data's Q95 sequence
    length, equally sized samples are drawn, all sets are restricted to the same
    columns, and dtypes are harmonized. The aligned sets are then encoded via
    `encode_data`.

    Args:
        syn_tgt_data: Synthetic target data.
        trn_tgt_data: Training (original) target data.
        hol_tgt_data: Optional holdout target data.
        syn_ctx_data: Optional synthetic context data; used when `trn_ctx_data` is given.
        trn_ctx_data: Optional training context data.
        hol_ctx_data: Optional holdout context data.
        ctx_primary_key: Primary key column of the context data.
        tgt_context_key: Foreign key column of the target data linking to the context.
        max_sample_size: Optional upper bound on rows sampled from each set.

    Returns:
        Tuple `(syn_embeds, trn_embeds, hol_embeds)`; `hol_embeds` is None when
        no holdout data was provided.
    """
    # helper variables
    key = tgt_context_key or None
    hol = hol_tgt_data is not None

    # filter target rows down to those that have a matching context key
    if trn_ctx_data is not None:
        rename_key = {ctx_primary_key: key}
        syn_ctx_data = syn_ctx_data[[ctx_primary_key]].rename(columns=rename_key)
        trn_ctx_data = trn_ctx_data[[ctx_primary_key]].rename(columns=rename_key)
        syn_tgt_data = syn_tgt_data.merge(syn_ctx_data, on=key, how="inner")
        trn_tgt_data = trn_tgt_data.merge(trn_ctx_data, on=key, how="inner")
        if hol:
            hol_ctx_data = hol_ctx_data[[ctx_primary_key]].rename(columns=rename_key)
            hol_tgt_data = hol_tgt_data.merge(hol_ctx_data, on=key, how="inner")

    if tgt_context_key is not None:
        # enrich sequential data with a per-group sequence-length column
        syn_tgt_data.insert(0, COUNT_COLUMN, syn_tgt_data.groupby(key)[key].transform("size"))
        trn_tgt_data.insert(0, COUNT_COLUMN, trn_tgt_data.groupby(key)[key].transform("size"))
        if hol:
            hol_tgt_data.insert(0, COUNT_COLUMN, hol_tgt_data.groupby(key)[key].transform("size"))

        # cap to Q95 sequence length of original to avoid excessive samples per
        # group distorting results; cast to int because quantile() returns a
        # float and DataFrame.head() requires an integer
        q95_sequence_length = int(trn_tgt_data.groupby(key).size().quantile(0.95))
        syn_tgt_data = syn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
        trn_tgt_data = trn_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)
        if hol:
            hol_tgt_data = hol_tgt_data.groupby(key).sample(frac=1).groupby(key).head(n=q95_sequence_length)

        # drop key from data as it's not relevant for embeddings
        syn_tgt_data = syn_tgt_data.drop(columns=[key])
        trn_tgt_data = trn_tgt_data.drop(columns=[key])
        if hol:
            hol_tgt_data = hol_tgt_data.drop(columns=[key])

    # draw equally sized samples for a fair 3-way comparison
    max_sample_size = min(
        max_sample_size or float("inf"),
        len(syn_tgt_data),
        len(trn_tgt_data),
        len(hol_tgt_data) if hol else float("inf"),
    )
    syn_tgt_data = syn_tgt_data.sample(n=max_sample_size)
    trn_tgt_data = trn_tgt_data.sample(n=max_sample_size)
    if hol:
        hol_tgt_data = hol_tgt_data.sample(n=max_sample_size)

    # limit all sets to the same (capped) list of training columns
    trn_cols = list(trn_tgt_data.columns)[:EMBEDDINGS_MAX_COLUMNS]
    syn_tgt_data = syn_tgt_data[trn_cols]
    trn_tgt_data = trn_tgt_data[trn_cols]
    if hol:
        hol_tgt_data = hol_tgt_data[trn_cols]

    # harmonize dtypes so all sets are encoded consistently
    syn_tgt_data = syn_tgt_data.apply(harmonize_dtype)
    trn_tgt_data = trn_tgt_data.apply(harmonize_dtype)
    if hol:
        hol_tgt_data = hol_tgt_data.apply(harmonize_dtype)

    # encode data into embedding vectors
    syn_embeds, trn_embeds, hol_embeds = encode_data(
        syn_data=syn_tgt_data,
        trn_data=trn_tgt_data,
        hol_data=hol_tgt_data,
    )

    return syn_embeds, trn_embeds, hol_embeds
0 commit comments