Skip to content

Commit 7058be3

Browse files
feat: reproducibility (#197)
1 parent d9790c1 commit 7058be3

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

mostlyai/qa/_common.py

+20
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
17+
import struct
1618
from typing import Protocol
1719

1820
import pandas as pd
@@ -122,3 +124,21 @@ def determine_data_size(
122124
return len(tgt_keys)
123125
else:
124126
return len(tgt_data)
127+
128+
129+
def set_random_state(random_state: int | None = None):
130+
def get_random_int_from_os() -> int:
131+
# 32-bit, cryptographically secure random int from os
132+
return int(struct.unpack("I", os.urandom(4))[0])
133+
134+
if random_state is not None:
135+
_LOG.info(f"Global random_state set to `{random_state}`")
136+
137+
if random_state is None:
138+
random_state = get_random_int_from_os()
139+
140+
import random
141+
import numpy as np
142+
143+
random.seed(random_state)
144+
np.random.seed(random_state)

mostlyai/qa/reporting.py

+5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
TGT_COLUMN_PREFIX,
6363
REPORT_CREDITS,
6464
ProgressCallbackWrapper,
65+
set_random_state,
6566
)
6667
from mostlyai.qa._filesystem import Statistics, TemporaryWorkspace
6768

@@ -87,6 +88,7 @@ def report(
8788
max_sample_size_embeddings: int | None = None,
8889
statistics_path: str | Path | None = None,
8990
update_progress: ProgressCallback | None = None,
91+
random_state: int | None = None,
9092
) -> tuple[Path, ModelMetrics | None]:
9193
"""
9294
Generate an HTML report and metrics for assessing synthetic data quality.
@@ -121,12 +123,15 @@ def report(
121123
max_sample_size_embeddings: The maximum sample size for embedding calculations.
122124
statistics_path: The path of where to store the statistics to be used by `report_from_statistics`
123125
update_progress: The progress callback.
126+
random_state: Seed for the random number generators.
124127
125128
Returns:
126129
The path to the generated HTML report.
127130
Metrics instance with accuracy, similarity, and distances metrics.
128131
"""
129132

133+
set_random_state(random_state)
134+
130135
if syn_ctx_data is not None:
131136
if ctx_primary_key is None:
132137
raise ValueError("If syn_ctx_data is provided, then ctx_primary_key must also be provided.")

mostlyai/qa/reporting_from_statistics.py

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
determine_data_size,
3434
REPORT_CREDITS,
3535
ProgressCallbackWrapper,
36+
set_random_state,
3637
)
3738
from mostlyai.qa._filesystem import Statistics, TemporaryWorkspace
3839

@@ -53,6 +54,7 @@ def report_from_statistics(
5354
max_sample_size_accuracy: int | None = None,
5455
max_sample_size_coherence: int | None = None,
5556
update_progress: ProgressCallback | None = None,
57+
random_state: int | None = None,
5658
) -> Path:
5759
"""
5860
Generate an HTML report based on previously generated statistics and newly provided synthetic data samples.
@@ -70,11 +72,14 @@ def report_from_statistics(
7072
max_sample_size_accuracy: The maximum sample size for accuracy calculations.
7173
max_sample_size_coherence: The maximum sample size for coherence calculations.
7274
update_progress: The progress callback.
75+
random_state: Seed for the random number generators.
7376
7477
Returns:
7578
The path to the generated HTML report.
7679
"""
7780

81+
set_random_state(random_state)
82+
7883
with (
7984
TemporaryWorkspace() as workspace,
8085
ProgressCallbackWrapper(update_progress) as progress,

0 commit comments

Comments
 (0)