
Commit 4204336

Merge pull request #1 from ethicalabs-ai/hf-src-dataset
Add HF datasets as a source.
2 parents 5219b3a + dc9ad4c commit 4204336

File tree

2 files changed: +110, -69 lines changed


create_prompts.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import json
 import os
-import random
 
 import click
 
main.py

Lines changed: 110 additions & 68 deletions
@@ -2,7 +2,7 @@
 import logging
 import re
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import click
 import numpy as np
@@ -186,7 +186,7 @@ def run_experiment(self, query: str) -> Dict:
                 "final_response": "",
             }
 
-        # Step 2: Rank responses based on critique scores
+        # Step 2: Rank responses based on semantic similarity
         ranked = self.rank_responses(candidates, query)
 
         # Step 3: Select the best response (highest-ranked) for recursive improvement
@@ -238,27 +238,22 @@ def load_prompt_records(prompts_file: str):
     """
     Load prompt records from a file.
 
-    If prompts_file ends with '.json', it is assumed to be a JSONL file:
+    If prompts_file ends with '.jsonl', it is assumed to be a JSONL file:
     each line is a JSON object containing at least a "prompt" key.
-    e.g. {"prompt": "...", "some_other_field": "..."}
-
     Otherwise, it is treated as a plain text file:
     each non-empty line is treated as a prompt (string).
     """
     if prompts_file.endswith(".jsonl"):
-        # JSONL file
         with open(prompts_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 if not line:
                     continue
                 record = json.loads(line)
-                # Ensure there's at least a 'prompt' key
                 if "prompt" not in record:
                     raise ValueError("JSON lines must contain a 'prompt' field.")
                 yield record
     else:
-        # Plain text: each line is a prompt
         with open(prompts_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
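
For readers unfamiliar with load_prompt_records, here is a minimal sketch of the two input shapes it accepts. The file names and prompt texts below are placeholders, not part of this commit, and main.py is assumed to be importable:

# prompts.jsonl: one JSON object per line, each with at least a "prompt" key, e.g.
#   {"prompt": "Explain recursion.", "topic": "cs"}
# prompts.txt: one prompt per non-empty line, e.g.
#   Explain recursion.

from main import load_prompt_records  # assumption: main.py is on the import path

for record in load_prompt_records("prompts.jsonl"):
    print(record["prompt"])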
@@ -268,17 +263,26 @@ def load_prompt_records(prompts_file: str):
 
 
 def run_experiment_on_prompts(
-    prompts_file: str,
     domain: str,
     model_name: str,
     critique_model_name: str,
     iteration_limit: int,
-    existing_dataset: Dataset = None,
+    existing_dataset: Optional[Dataset] = None,
     force: bool = False,
+    source_dataset_file: Optional[str] = None,
+    source_dataset_hf: Optional[str] = None,
+    prompt_field: str = "prompt",
+    response_field: str = "response",
+    dataset_name: Optional[str] = None,
 ) -> Dataset:
     """
-    Runs the Ouroboros pipeline on prompts from `prompts_file`,
-    merges with existing_dataset if provided, and prevents duplicate records.
+    Runs the Ouroboros pipeline on prompts.
+
+    If source_dataset_file or source_dataset_hf is provided, each record is assumed to contain a prompt
+    and an existing response (using prompt_field and response_field). The experiment
+    then refines the existing response.
+    Otherwise, it loads prompts from prompts_file and generates responses from scratch.
+    The final record includes the domain, source dataset info, and a dataset_name.
     """
     ai_experiment = RecursiveAIExperiment(
         model_name, critique_model_name, iteration_limit
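
As a rough illustration of the refinement mode described in the new docstring, a direct call using only the keyword arguments visible in this diff might look like the sketch below; the model names and the Hugging Face dataset id are placeholders:

from main import run_experiment_on_prompts  # assumption: main.py is importable

updated = run_experiment_on_prompts(
    domain="general",
    model_name="generator-model",           # placeholder
    critique_model_name="critique-model",   # placeholder
    iteration_limit=3,
    source_dataset_hf="user/dataset_name",  # placeholder HF dataset id
    prompt_field="prompt",
    response_field="response",
    dataset_name="my-refined-dataset",      # placeholder
)
print(len(updated))  # returns a datasets.Dataset of refined records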
@@ -289,49 +293,53 @@ def run_experiment_on_prompts(
     if existing_dataset is not None and "input" in existing_dataset.column_names:
         existing_records = {row["input"]: row for row in existing_dataset}
 
-    for i, record in enumerate(load_prompt_records(prompts_file), start=1):
-        prompt = record["prompt"]
-
-        # Handle duplicates:
-        if prompt in existing_records:
-            if not force:
-                logging.info(f"Skipping existing prompt: {prompt}")
-                continue
-            else:
-                logging.info(f"Replacing existing prompt due to --force: {prompt}")
-
-        # Run experiment
-        logging.info(f"Running experiment for prompt #{i}: {prompt}")
-        result = ai_experiment.run_experiment(prompt)
-        reasoning_steps = extract_reasoning(result["final_response"])
-
+    records = []
+    source_name = None
+    if source_dataset_file:
+        records = list(load_prompt_records(source_dataset_file))
+        source_name = source_dataset_file
+    elif source_dataset_hf:
+        ds = load_dataset(source_dataset_hf)
+        split = list(ds.keys())[0]
+        records = ds[split]
+        source_name = source_dataset_hf
+
+    for i, record in enumerate(records, start=1):
+        prompt = record.get(prompt_field)
+        original_response = record.get(response_field)
+        if prompt is None or original_response is None:
+            logging.warning(f"Skipping record #{i} due to missing fields.")
+            continue
+
+        if prompt in existing_records and not force:
+            logging.info(f"Skipping existing prompt: {prompt}")
+            continue
+
+        logging.info(f"Refining record #{i} for prompt: {prompt}")
+        refined_response = ai_experiment.recursive_improvement(
+            original_response, prompt
+        )
+        reasoning_steps = extract_reasoning(refined_response)
         new_entry = {
             "input": prompt,
+            "original_response": original_response,
+            "completion": clean_response(refined_response),
             "reasoning": reasoning_steps if reasoning_steps else None,
-            "completion": clean_response(result["final_response"]),
-            "refinements": result["ranked_responses"],
             "domain": domain,
+            "source_dataset": source_name,
+            "dataset_name": dataset_name,
         }
-        # Keep other keys from the record if present
         for k, v in record.items():
-            if k != "prompt":
+            if k not in {prompt_field, response_field}:
                 new_entry[k] = v
 
-        # Update the existing record OR add new one
         existing_records[prompt] = new_entry
 
-    # Create updated dataset without duplicates
     updated_dataset = Dataset.from_list(list(existing_records.values()))
     return updated_dataset
 
 
 @click.command()
-@click.option(
-    "--prompt_dir",
-    type=click.Path(exists=True, file_okay=False),
-    required=True,
-    help="Directory containing prompt files, categorized by domain.",
-)
 @click.option(
     "--output_dir",
     type=click.Path(file_okay=False),
@@ -377,8 +385,41 @@ def run_experiment_on_prompts(
     is_flag=True,
     help="Automatically confirm overwriting files without prompting.",
 )
+@click.option(
+    "--source_dataset",
+    type=click.Path(exists=True),
+    help="Path to the source dataset file to use for refinement (JSONL or text).",
+)
+@click.option(
+    "--source_dataset_hf",
+    type=str,
+    help="Hugging Face dataset path to use for refinement, e.g., user/dataset_name",
+)
+@click.option(
+    "--prompt_field",
+    type=str,
+    default="prompt",
+    help="Field name in the source dataset that contains the prompt.",
+)
+@click.option(
+    "--response_field",
+    type=str,
+    default="response",
+    help="Field name in the source dataset that contains the response.",
+)
+@click.option(
+    "--dataset_name",
+    type=str,
+    required=True,
+    help="Name to assign to the final dataset.",
+)
+@click.option(
+    "--domain",
+    type=str,
+    default=None,
+    help="Set the domain for the dataset.",
+)
 def main(
-    prompt_dir,
     output_dir,
     hf_dataset,
     model_name,
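
A hedged sketch of how the new options could be exercised through click's test runner. The dataset id, field names, and output directory are placeholders, and options defined outside this diff (such as --model_name) are assumed to have defaults or would need to be supplied as well:

from click.testing import CliRunner

from main import main  # assumption: main.py is importable

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "--output_dir", "out",
        "--source_dataset_hf", "user/dataset_name",  # placeholder
        "--prompt_field", "prompt",
        "--response_field", "response",
        "--dataset_name", "my-refined-dataset",      # placeholder
        "--domain", "general",
    ],
)
print(result.exit_code, result.output)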
@@ -387,64 +428,66 @@ def main(
     force,
     push_to_hf,
     yes,
+    source_dataset,
+    source_dataset_hf,
+    prompt_field,
+    response_field,
+    dataset_name,
+    domain,
 ):
     logging.basicConfig(level=logging.INFO)
-    prompt_path = Path(prompt_dir)
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
 
-    # 1) Load existing dataset from HF (if specified)
     existing_dataset = None
     if hf_dataset:
         logging.info(f"Loading dataset from Hugging Face: {hf_dataset}")
         try:
             ds_dict = load_dataset(hf_dataset)
-            first_split = list(ds_dict.keys())[0]  # e.g. "train"
+            first_split = list(ds_dict.keys())[0]
             existing_dataset = ds_dict[first_split]
             logging.info(f"Loaded '{hf_dataset}' with {len(existing_dataset)} rows.")
         except Exception as e:
             logging.warning(f"Failed to load dataset: {e}")
             existing_dataset = None
 
-    # 2) Iterate prompt files to accumulate updates in a single dataset
     merged_dataset = existing_dataset
     changes_detected = False
 
-    for file in prompt_path.glob("*.*"):
-        domain = file.stem
-        logging.info(f"Processing domain: {domain} from file: {file}")
+    if source_dataset or source_dataset_hf:
+        current_domain = domain if domain else "default"
+        logging.info(
+            f"Processing source dataset: {source_dataset or source_dataset_hf} with domain: {current_domain}"
+        )
         updated_dataset = run_experiment_on_prompts(
-            prompts_file=str(file),
-            domain=domain,
+            domain=current_domain,
            model_name=model_name,
            critique_model_name=critique_model_name,
            iteration_limit=num_iterations,
            existing_dataset=merged_dataset,
            force=force,
+            source_dataset_file=source_dataset,
+            source_dataset_hf=source_dataset_hf,
+            prompt_field=prompt_field,
+            response_field=response_field,
+            dataset_name=dataset_name,
         )
-        # If updated dataset is bigger => new data was added
-        if (
-            force
-            or merged_dataset is None
-            or len(updated_dataset) > len(merged_dataset)
-        ):
-            changes_detected = True
-            merged_dataset = updated_dataset  # Keep the newly updated dataset
-        else:
-            logging.info(f"No new prompts were added for domain: {domain}.")
+        changes_detected = True
+        merged_dataset = updated_dataset
 
-    # 3) If changes were detected, save locally
     if changes_detected and merged_dataset is not None:
-        # Just pick a single name or store separate domain files if needed
         dataset_path_parquet = output_path / "ouroboros_dataset.parquet"
         dataset_path_json = output_path / "ouroboros_dataset.json"
 
-        if dataset_path_parquet.exists():
-            if not yes and click.confirm(
+        if (
+            dataset_path_parquet.exists()
+            and not yes
+            and not click.confirm(
                 f"{dataset_path_parquet} exists. Overwrite?", default=True
-            ):
-                logging.info("Skipping save due to user cancel.")
-                return
+            )
+        ):
+            logging.info("User cancelled overwrite. Exiting.")
+            return
 
         merged_dataset.to_parquet(str(dataset_path_parquet))
         merged_dataset.to_json(str(dataset_path_json))
@@ -454,7 +497,6 @@ def main(
     else:
         logging.info("No changes detected. Skipping local save.")
 
-    # 4) Push once to Hugging Face if requested
     if push_to_hf and hf_dataset and changes_detected and merged_dataset is not None:
         token = HfFolder.get_token()
         if not token:
