Skip to content

Commit bb8b910

Browse files
authored
fix: fixed issue with RANDOM sampling of rows in case of sequential data (#120)
1 parent b4a0394 commit bb8b910

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

mostlyai/qa/_sampling.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,7 @@ def sample_two_consecutive_rows(
152152
seq_lens = df.groupby(col_by).size()
153153

154154
# make random draw from [0, seq_len-1]
155-
sel_idx = (seq_lens - 1) * np.random.random(len(seq_lens)).astype("int")
155+
sel_idx = ((seq_lens - 1) * np.random.random(len(seq_lens))).astype("int")
156156
sel_idx_df = pd.Series(sel_idx).to_frame("__IDX").reset_index()
157157

158158
# filter to randomly selected indices

tests/unit/test_accuracy.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -340,8 +340,14 @@ def test_plot_univariate_distribution_numeric():
340340

341341

342342
def test_sample_two_consecutive_rows():
343-
df = pd.DataFrame({"id": [1, 1, 1, 1, 2, 2, 2, 3, 3, 4], "x": [1, 2, 3, 5, 1, 2, 3, 1, 2, 1]})
343+
df = pd.DataFrame(
344+
{
345+
"id": [1] * 1000 + [2] * 500 + [3] * 2 + [4] * 1,
346+
"x": list(range(1000)) + list(range(500)) + list(range(2)) + list(range(1)),
347+
}
348+
)
344349
first_rows, second_rows = sample_two_consecutive_rows(df=df, col_by="id")
350+
assert not (first_rows["x"] == 0).all()
345351
assert len(first_rows) == 4
346352
assert len(second_rows) == 3
347353
assert (first_rows["x"][0:2] == second_rows["x"][0:2] - 1).all()

0 commit comments

Comments
 (0)