File tree 3 files changed +30
-0
lines changed
3 files changed +30
-0
lines changed Original file line number Diff line number Diff line change 13
13
# limitations under the License.
14
14
15
15
import logging
16
+ import os
17
+ import struct
16
18
from typing import Protocol
17
19
18
20
import pandas as pd
@@ -122,3 +124,21 @@ def determine_data_size(
122
124
return len (tgt_keys )
123
125
else :
124
126
return len (tgt_data )
127
+
128
+
129
+ def set_random_state (random_state : int | None = None ):
130
+ def get_random_int_from_os () -> int :
131
+ # 32-bit, cryptographically secure random int from os
132
+ return int (struct .unpack ("I" , os .urandom (4 ))[0 ])
133
+
134
+ if random_state is not None :
135
+ _LOG .info (f"Global random_state set to `{ random_state } `" )
136
+
137
+ if random_state is None :
138
+ random_state = get_random_int_from_os ()
139
+
140
+ import random
141
+ import numpy as np
142
+
143
+ random .seed (random_state )
144
+ np .random .seed (random_state )
Original file line number Diff line number Diff line change 62
62
TGT_COLUMN_PREFIX ,
63
63
REPORT_CREDITS ,
64
64
ProgressCallbackWrapper ,
65
+ set_random_state ,
65
66
)
66
67
from mostlyai .qa ._filesystem import Statistics , TemporaryWorkspace
67
68
@@ -87,6 +88,7 @@ def report(
87
88
max_sample_size_embeddings : int | None = None ,
88
89
statistics_path : str | Path | None = None ,
89
90
update_progress : ProgressCallback | None = None ,
91
+ random_state : int | None = None ,
90
92
) -> tuple [Path , ModelMetrics | None ]:
91
93
"""
92
94
Generate an HTML report and metrics for assessing synthetic data quality.
@@ -121,12 +123,15 @@ def report(
121
123
max_sample_size_embeddings: The maximum sample size for embedding calculations.
122
124
statistics_path: The path of where to store the statistics to be used by `report_from_statistics`
123
125
update_progress: The progress callback.
126
+ random_state: Seed for the random number generators.
124
127
125
128
Returns:
126
129
The path to the generated HTML report.
127
130
Metrics instance with accuracy, similarity, and distances metrics.
128
131
"""
129
132
133
+ set_random_state (random_state )
134
+
130
135
if syn_ctx_data is not None :
131
136
if ctx_primary_key is None :
132
137
raise ValueError ("If syn_ctx_data is provided, then ctx_primary_key must also be provided." )
Original file line number Diff line number Diff line change 33
33
determine_data_size ,
34
34
REPORT_CREDITS ,
35
35
ProgressCallbackWrapper ,
36
+ set_random_state ,
36
37
)
37
38
from mostlyai .qa ._filesystem import Statistics , TemporaryWorkspace
38
39
@@ -53,6 +54,7 @@ def report_from_statistics(
53
54
max_sample_size_accuracy : int | None = None ,
54
55
max_sample_size_coherence : int | None = None ,
55
56
update_progress : ProgressCallback | None = None ,
57
+ random_state : int | None = None ,
56
58
) -> Path :
57
59
"""
58
60
Generate an HTML report based on previously generated statistics and newly provided synthetic data samples.
@@ -70,11 +72,14 @@ def report_from_statistics(
70
72
max_sample_size_accuracy: The maximum sample size for accuracy calculations.
71
73
max_sample_size_coherence: The maximum sample size for coherence calculations.
72
74
update_progress: The progress callback.
75
+ random_state: Seed for the random number generators.
73
76
74
77
Returns:
75
78
The path to the generated HTML report.
76
79
"""
77
80
81
+ set_random_state (random_state )
82
+
78
83
with (
79
84
TemporaryWorkspace () as workspace ,
80
85
ProgressCallbackWrapper (update_progress ) as progress ,
You can’t perform that action at this time.
0 commit comments