 import logging
 import re
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 import click
 import numpy as np
@@ -186,7 +186,7 @@ def run_experiment(self, query: str) -> Dict:
             "final_response": "",
         }

-        # Step 2: Rank responses based on critique scores
+        # Step 2: Rank responses based on semantic similarity
         ranked = self.rank_responses(candidates, query)

         # Step 3: Select the best response (highest-ranked) for recursive improvement
@@ -238,27 +238,22 @@ def load_prompt_records(prompts_file: str):
     """
     Load prompt records from a file.

-    If prompts_file ends with '.json', it is assumed to be a JSONL file:
+    If prompts_file ends with '.jsonl', it is assumed to be a JSONL file:
     each line is a JSON object containing at least a "prompt" key.
-    e.g. {"prompt": "...", "some_other_field": "..."}
-
     Otherwise, it is treated as a plain text file:
     each non-empty line is treated as a prompt (string).
     """
     if prompts_file.endswith(".jsonl"):
-        # JSONL file
         with open(prompts_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 if not line:
                     continue
                 record = json.loads(line)
-                # Ensure there's at least a 'prompt' key
                 if "prompt" not in record:
                     raise ValueError("JSON lines must contain a 'prompt' field.")
                 yield record
     else:
-        # Plain text: each line is a prompt
         with open(prompts_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
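
For illustration only, here is a minimal Python sketch of the JSONL input this loader expects after the change; the file name, prompts, and extra keys are placeholders, not part of the commit. Each line is a standalone JSON object with a "prompt" key, and, for the new refinement flow, an existing "response"; any additional keys are carried through into the final dataset entry.

import json

# Placeholder records illustrating the expected JSONL shape (not part of the commit):
# a "prompt", an existing "response" to refine, and optional extra metadata.
sample_records = [
    {"prompt": "Explain the Ouroboros pipeline.", "response": "It refines answers iteratively."},
    {"prompt": "What is recursion?", "response": "A function that calls itself.", "topic": "cs"},
]
with open("prompts.jsonl", "w", encoding="utf-8") as f:
    for rec in sample_records:
        f.write(json.dumps(rec) + "\n")
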
@@ -268,17 +263,26 @@ def load_prompt_records(prompts_file: str):


 def run_experiment_on_prompts(
-    prompts_file: str,
     domain: str,
     model_name: str,
     critique_model_name: str,
     iteration_limit: int,
-    existing_dataset: Dataset = None,
+    existing_dataset: Optional[Dataset] = None,
     force: bool = False,
+    source_dataset_file: Optional[str] = None,
+    source_dataset_hf: Optional[str] = None,
+    prompt_field: str = "prompt",
+    response_field: str = "response",
+    dataset_name: Optional[str] = None,
 ) -> Dataset:
     """
-    Runs the Ouroboros pipeline on prompts from `prompts_file`,
-    merges with existing_dataset if provided, and prevents duplicate records.
+    Runs the Ouroboros pipeline on prompts.
+
+    If source_dataset_file or source_dataset_hf is provided, each record is assumed to contain a prompt
+    and an existing response (using prompt_field and response_field). The experiment
+    then refines the existing response.
+    Otherwise, it loads prompts from prompts_file and generates responses from scratch.
+    The final record includes the domain, source dataset info, and a dataset_name.
     """
     ai_experiment = RecursiveAIExperiment(
         model_name, critique_model_name, iteration_limit
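
As a rough usage sketch of the new signature, the call below drives the refinement path directly. The module name "ouroboros" and the model identifiers are hypothetical placeholders; the function is assumed to be importable from the script under review, and the keyword arguments mirror the parameters added in this commit.

from ouroboros import run_experiment_on_prompts  # hypothetical module name for this script

dataset = run_experiment_on_prompts(
    domain="general",
    model_name="base-model",                 # placeholder model identifiers
    critique_model_name="critique-model",
    iteration_limit=3,
    source_dataset_file="prompts.jsonl",     # records holding prompt_field/response_field
    prompt_field="prompt",
    response_field="response",
    dataset_name="refined_demo",
)
print(len(dataset), dataset.column_names)
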
@@ -289,49 +293,53 @@ def run_experiment_on_prompts(
     if existing_dataset is not None and "input" in existing_dataset.column_names:
         existing_records = {row["input"]: row for row in existing_dataset}

-    for i, record in enumerate(load_prompt_records(prompts_file), start=1):
-        prompt = record["prompt"]
-
-        # Handle duplicates:
-        if prompt in existing_records:
-            if not force:
-                logging.info(f"Skipping existing prompt: {prompt}")
-                continue
-            else:
-                logging.info(f"Replacing existing prompt due to --force: {prompt}")
-
-        # Run experiment
-        logging.info(f"Running experiment for prompt #{i}: {prompt}")
-        result = ai_experiment.run_experiment(prompt)
-        reasoning_steps = extract_reasoning(result["final_response"])
-
+    records = []
+    source_name = None
+    if source_dataset_file:
+        records = list(load_prompt_records(source_dataset_file))
+        source_name = source_dataset_file
+    elif source_dataset_hf:
+        ds = load_dataset(source_dataset_hf)
+        split = list(ds.keys())[0]
+        records = ds[split]
+        source_name = source_dataset_hf
+
+    for i, record in enumerate(records, start=1):
+        prompt = record.get(prompt_field)
+        original_response = record.get(response_field)
+        if prompt is None or original_response is None:
+            logging.warning(f"Skipping record #{i} due to missing fields.")
+            continue
+
+        if prompt in existing_records and not force:
+            logging.info(f"Skipping existing prompt: {prompt}")
+            continue
+
+        logging.info(f"Refining record #{i} for prompt: {prompt}")
+        refined_response = ai_experiment.recursive_improvement(
+            original_response, prompt
+        )
+        reasoning_steps = extract_reasoning(refined_response)
         new_entry = {
             "input": prompt,
+            "original_response": original_response,
+            "completion": clean_response(refined_response),
             "reasoning": reasoning_steps if reasoning_steps else None,
-            "completion": clean_response(result["final_response"]),
-            "refinements": result["ranked_responses"],
             "domain": domain,
+            "source_dataset": source_name,
+            "dataset_name": dataset_name,
         }
-        # Keep other keys from the record if present
         for k, v in record.items():
-            if k != "prompt":
+            if k not in {prompt_field, response_field}:
                 new_entry[k] = v

-        # Update the existing record OR add new one
         existing_records[prompt] = new_entry

-    # Create updated dataset without duplicates
     updated_dataset = Dataset.from_list(list(existing_records.values()))
     return updated_dataset


 @click.command()
-@click.option(
-    "--prompt_dir",
-    type=click.Path(exists=True, file_okay=False),
-    required=True,
-    help="Directory containing prompt files, categorized by domain.",
-)
 @click.option(
     "--output_dir",
     type=click.Path(file_okay=False),
@@ -377,8 +385,41 @@ def run_experiment_on_prompts(
     is_flag=True,
     help="Automatically confirm overwriting files without prompting.",
 )
+@click.option(
+    "--source_dataset",
+    type=click.Path(exists=True),
+    help="Path to the source dataset file to use for refinement (JSONL or text).",
+)
+@click.option(
+    "--source_dataset_hf",
+    type=str,
+    help="Hugging Face dataset path to use for refinement, e.g., user/dataset_name",
+)
+@click.option(
+    "--prompt_field",
+    type=str,
+    default="prompt",
+    help="Field name in the source dataset that contains the prompt.",
+)
+@click.option(
+    "--response_field",
+    type=str,
+    default="response",
+    help="Field name in the source dataset that contains the response.",
+)
+@click.option(
+    "--dataset_name",
+    type=str,
+    required=True,
+    help="Name to assign to the final dataset.",
+)
+@click.option(
+    "--domain",
+    type=str,
+    default=None,
+    help="Set the domain for the dataset.",
+)
 def main(
-    prompt_dir,
     output_dir,
     hf_dataset,
     model_name,
@@ -387,64 +428,66 @@ def main(
     force,
     push_to_hf,
     yes,
+    source_dataset,
+    source_dataset_hf,
+    prompt_field,
+    response_field,
+    dataset_name,
+    domain,
 ):
     logging.basicConfig(level=logging.INFO)
-    prompt_path = Path(prompt_dir)
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)

-    # 1) Load existing dataset from HF (if specified)
     existing_dataset = None
     if hf_dataset:
         logging.info(f"Loading dataset from Hugging Face: {hf_dataset}")
         try:
             ds_dict = load_dataset(hf_dataset)
-            first_split = list(ds_dict.keys())[0]  # e.g. "train"
+            first_split = list(ds_dict.keys())[0]
             existing_dataset = ds_dict[first_split]
             logging.info(f"Loaded '{hf_dataset}' with {len(existing_dataset)} rows.")
         except Exception as e:
             logging.warning(f"Failed to load dataset: {e}")
             existing_dataset = None

-    # 2) Iterate prompt files to accumulate updates in a single dataset
     merged_dataset = existing_dataset
     changes_detected = False

-    for file in prompt_path.glob("*.*"):
-        domain = file.stem
-        logging.info(f"Processing domain: {domain} from file: {file}")
+    if source_dataset or source_dataset_hf:
+        current_domain = domain if domain else "default"
+        logging.info(
+            f"Processing source dataset: {source_dataset or source_dataset_hf} with domain: {current_domain}"
+        )
         updated_dataset = run_experiment_on_prompts(
-            prompts_file=str(file),
-            domain=domain,
+            domain=current_domain,
             model_name=model_name,
             critique_model_name=critique_model_name,
             iteration_limit=num_iterations,
             existing_dataset=merged_dataset,
             force=force,
+            source_dataset_file=source_dataset,
+            source_dataset_hf=source_dataset_hf,
+            prompt_field=prompt_field,
+            response_field=response_field,
+            dataset_name=dataset_name,
         )
-        # If updated dataset is bigger => new data was added
-        if (
-            force
-            or merged_dataset is None
-            or len(updated_dataset) > len(merged_dataset)
-        ):
-            changes_detected = True
-            merged_dataset = updated_dataset  # Keep the newly updated dataset
-        else:
-            logging.info(f"No new prompts were added for domain: {domain}.")
+        changes_detected = True
+        merged_dataset = updated_dataset

-    # 3) If changes were detected, save locally
     if changes_detected and merged_dataset is not None:
-        # Just pick a single name or store separate domain files if needed
         dataset_path_parquet = output_path / "ouroboros_dataset.parquet"
         dataset_path_json = output_path / "ouroboros_dataset.json"

-        if dataset_path_parquet.exists():
-            if not yes and click.confirm(
+        if (
+            dataset_path_parquet.exists()
+            and not yes
+            and not click.confirm(
                 f"{dataset_path_parquet} exists. Overwrite?", default=True
-            ):
-                logging.info("Skipping save due to user cancel.")
-                return
+            )
+        ):
+            logging.info("User cancelled overwrite. Exiting.")
+            return

         merged_dataset.to_parquet(str(dataset_path_parquet))
         merged_dataset.to_json(str(dataset_path_json))
@@ -454,7 +497,6 @@ def main(
     else:
         logging.info("No changes detected. Skipping local save.")

-    # 4) Push once to Hugging Face if requested
     if push_to_hf and hf_dataset and changes_detected and merged_dataset is not None:
         token = HfFolder.get_token()
         if not token: