@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
345
345
size_t last_sym_start = rule.size ();
346
346
const char * pos = src;
347
347
348
- auto handle_repetitions = [&](int min_times, int max_times) {
348
+ auto handle_repetitions = [&](int min_times, int max_times) {
349
349
350
- if (last_sym_start == rule.size ()) {
351
- throw std::runtime_error (std::string (" expecting preceding item to */+/?/{ at " ) + pos);
352
- }
350
+ if (last_sym_start == rule.size ()) {
351
+ throw std::runtime_error (std::string (" expecting preceding item to */+/?/{ at " ) + pos);
352
+ }
353
353
354
- // apply transformation to previous symbol (last_sym_start to end) according to
355
- // the following rewrite rules:
356
- // S{m,n} --> S S S (m times) S'(n-m)
357
- // S'(x) ::= S S'(x-1) |
358
- // (... n-m definitions of these S' rules ...)
359
- // S'(1) ::= S |
360
- // S{m,} --> S S S (m times) S'
361
- // S' ::= S S' |
362
- // S* --> S{0,}
363
- // --> S' ::= S S' |
364
- // S+ --> S{1,}
365
- // --> S S'
366
- // S' ::= S S' |
367
- // S? --> S{0,1}
368
- // --> S'
369
- // S' ::= S |
370
-
371
- llama_grammar_rule prev_rule (rule.begin () + last_sym_start, rule.end ());
372
- if (min_times == 0 ) {
373
- rule.resize (last_sym_start);
374
- } else {
375
- // Repeat the previous elements (min_times - 1) times
376
- for (int i = 1 ; i < min_times; i++) {
377
- rule.insert (rule.end (), prev_rule.begin (), prev_rule.end ());
378
- }
354
+ // apply transformation to previous symbol (last_sym_start to end) according to
355
+ // the following rewrite rules:
356
+ // S{m,n} --> S S S (m times) S'(n-m)
357
+ // S'(x) ::= S S'(x-1) |
358
+ // (... n-m definitions of these S' rules ...)
359
+ // S'(1) ::= S |
360
+ // S{m,} --> S S S (m times) S'
361
+ // S' ::= S S' |
362
+ // S* --> S{0,}
363
+ // --> S' ::= S S' |
364
+ // S+ --> S{1,}
365
+ // --> S S'
366
+ // S' ::= S S' |
367
+ // S? --> S{0,1}
368
+ // --> S'
369
+ // S' ::= S |
370
+
371
+ llama_grammar_rule prev_rule (rule.begin () + last_sym_start, rule.end ());
372
+ if (min_times == 0 ) {
373
+ rule.resize (last_sym_start);
374
+ } else {
375
+ // Repeat the previous elements (min_times - 1) times
376
+ for (int i = 1 ; i < min_times; i++) {
377
+ rule.insert (rule.end (), prev_rule.begin (), prev_rule.end ());
379
378
}
379
+ }
380
380
381
- uint32_t last_rec_rule_id = 0 ;
382
- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
381
+ uint32_t last_rec_rule_id = 0 ;
382
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
383
383
384
- llama_grammar_rule rec_rule (prev_rule);
385
- for (int i = 0 ; i < n_opt; i++) {
386
- rec_rule.resize (prev_rule.size ());
387
- uint32_t rec_rule_id = generate_symbol_id ( rule_name);
388
- if (i > 0 || max_times < 0 ) {
389
- rec_rule.push_back ({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
390
- }
391
- rec_rule.push_back ({LLAMA_GRETYPE_ALT, 0 });
392
- rec_rule.push_back ({LLAMA_GRETYPE_END, 0 });
393
- add_rule ( rec_rule_id, rec_rule);
394
- last_rec_rule_id = rec_rule_id;
395
- }
396
- if (n_opt > 0 ) {
397
- rule.push_back ({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
384
+ llama_grammar_rule rec_rule (prev_rule);
385
+ for (int i = 0 ; i < n_opt; i++) {
386
+ rec_rule.resize (prev_rule.size ());
387
+ uint32_t rec_rule_id = generate_symbol_id ( rule_name);
388
+ if (i > 0 || max_times < 0 ) {
389
+ rec_rule.push_back ({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
398
390
}
399
- };
391
+ rec_rule.push_back ({LLAMA_GRETYPE_ALT, 0 });
392
+ rec_rule.push_back ({LLAMA_GRETYPE_END, 0 });
393
+ add_rule ( rec_rule_id, rec_rule);
394
+ last_rec_rule_id = rec_rule_id;
395
+ }
396
+ if (n_opt > 0 ) {
397
+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
398
+ }
399
+ };
400
400
401
- while (*pos) {
402
- if (*pos == ' "' ) { // literal string
403
- pos++;
404
- last_sym_start = rule.size ();
405
- while (*pos != ' "' ) {
406
- if (!*pos) {
407
- throw std::runtime_error (" unexpected end of input" );
408
- }
409
- auto char_pair = parse_char (pos);
410
- pos = char_pair.second ;
411
- rule.push_back ({LLAMA_GRETYPE_CHAR, char_pair.first });
401
+ while (*pos) {
402
+ if (*pos == ' "' ) { // literal string
403
+ pos++;
404
+ last_sym_start = rule.size ();
405
+ while (*pos != ' "' ) {
406
+ if (!*pos) {
407
+ throw std::runtime_error (" unexpected end of input" );
412
408
}
413
- pos = parse_space (pos + 1 , is_nested);
414
- } else if (*pos == ' [' ) { // char range(s)
409
+ auto char_pair = parse_char (pos);
410
+ pos = char_pair.second ;
411
+ rule.push_back ({LLAMA_GRETYPE_CHAR, char_pair.first });
412
+ }
413
+ pos = parse_space (pos + 1 , is_nested);
414
+ } else if (*pos == ' [' ) { // char range(s)
415
+ pos++;
416
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417
+ if (*pos == ' ^' ) {
415
418
pos++;
416
- enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
417
- if (*pos == ' ^' ) {
418
- pos++;
419
- start_type = LLAMA_GRETYPE_CHAR_NOT;
419
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
420
+ }
421
+ last_sym_start = rule.size ();
422
+ while (*pos != ' ]' ) {
423
+ if (!*pos) {
424
+ throw std::runtime_error (" unexpected end of input" );
420
425
}
421
- last_sym_start = rule.size ();
422
- while (*pos != ' ]' ) {
423
- if (!*pos) {
426
+ auto char_pair = parse_char (pos);
427
+ pos = char_pair.second ;
428
+ enum llama_gretype type = last_sym_start < rule.size ()
429
+ ? LLAMA_GRETYPE_CHAR_ALT
430
+ : start_type;
431
+
432
+ rule.push_back ({type, char_pair.first });
433
+ if (pos[0 ] == ' -' && pos[1 ] != ' ]' ) {
434
+ if (!pos[1 ]) {
424
435
throw std::runtime_error (" unexpected end of input" );
425
436
}
426
- auto char_pair = parse_char (pos);
427
- pos = char_pair.second ;
428
- enum llama_gretype type = last_sym_start < rule.size ()
429
- ? LLAMA_GRETYPE_CHAR_ALT
430
- : start_type;
431
-
432
- rule.push_back ({type, char_pair.first });
433
- if (pos[0 ] == ' -' && pos[1 ] != ' ]' ) {
434
- if (!pos[1 ]) {
435
- throw std::runtime_error (" unexpected end of input" );
436
- }
437
- auto endchar_pair = parse_char (pos + 1 );
438
- pos = endchar_pair.second ;
439
- rule.push_back ({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first });
440
- }
441
- }
442
- pos = parse_space (pos + 1 , is_nested);
443
- } else if (is_word_char (*pos)) { // rule reference
444
- const char * name_end = parse_name (pos);
445
- uint32_t ref_rule_id = get_symbol_id (pos, name_end - pos);
446
- pos = parse_space (name_end, is_nested);
447
- last_sym_start = rule.size ();
448
- rule.push_back ({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449
- } else if (*pos == ' (' ) { // grouping
450
- // parse nested alternates into synthesized rule
451
- pos = parse_space (pos + 1 , true );
452
- uint32_t sub_rule_id = generate_symbol_id (rule_name);
453
- pos = parse_alternates (pos, rule_name, sub_rule_id, true );
454
- last_sym_start = rule.size ();
455
- // output reference to synthesized rule
456
- rule.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457
- if (*pos != ' )' ) {
458
- throw std::runtime_error (std::string (" expecting ')' at " ) + pos);
437
+ auto endchar_pair = parse_char (pos + 1 );
438
+ pos = endchar_pair.second ;
439
+ rule.push_back ({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first });
459
440
}
441
+ }
442
+ pos = parse_space (pos + 1 , is_nested);
443
+ } else if (is_word_char (*pos)) { // rule reference
444
+ const char * name_end = parse_name (pos);
445
+ uint32_t ref_rule_id = get_symbol_id (pos, name_end - pos);
446
+ pos = parse_space (name_end, is_nested);
447
+ last_sym_start = rule.size ();
448
+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
449
+ } else if (*pos == ' (' ) { // grouping
450
+ // parse nested alternates into synthesized rule
451
+ pos = parse_space (pos + 1 , true );
452
+ uint32_t sub_rule_id = generate_symbol_id (rule_name);
453
+ pos = parse_alternates (pos, rule_name, sub_rule_id, true );
454
+ last_sym_start = rule.size ();
455
+ // output reference to synthesized rule
456
+ rule.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
457
+ if (*pos != ' )' ) {
458
+ throw std::runtime_error (std::string (" expecting ')' at " ) + pos);
459
+ }
460
+ pos = parse_space (pos + 1 , is_nested);
461
+ } else if (*pos == ' .' ) { // any char
462
+ last_sym_start = rule.size ();
463
+ rule.push_back ({LLAMA_GRETYPE_CHAR_ANY, 0 });
464
+ pos = parse_space (pos + 1 , is_nested);
465
+ } else if (*pos == ' *' ) {
466
+ pos = parse_space (pos + 1 , is_nested);
467
+ handle_repetitions (0 , -1 );
468
+ } else if (*pos == ' +' ) {
469
+ pos = parse_space (pos + 1 , is_nested);
470
+ handle_repetitions (1 , -1 );
471
+ } else if (*pos == ' ?' ) {
472
+ pos = parse_space (pos + 1 , is_nested);
473
+ handle_repetitions (0 , 1 );
474
+ } else if (*pos == ' {' ) {
475
+ pos = parse_space (pos + 1 , is_nested);
476
+
477
+ if (!is_digit_char (*pos)) {
478
+ throw std::runtime_error (std::string (" expecting an int at " ) + pos);
479
+ }
480
+ const char * int_end = parse_int (pos);
481
+ int min_times = std::stoul (std::string (pos, int_end - pos));
482
+ pos = parse_space (int_end, is_nested);
483
+
484
+ int max_times = -1 ;
485
+
486
+ if (*pos == ' }' ) {
487
+ max_times = min_times;
460
488
pos = parse_space (pos + 1 , is_nested);
461
- } else if (*pos == ' .' ) { // any char
462
- last_sym_start = rule.size ();
463
- rule.push_back ({LLAMA_GRETYPE_CHAR_ANY, 0 });
464
- pos = parse_space (pos + 1 , is_nested);
465
- } else if (*pos == ' *' ) {
466
- pos = parse_space (pos + 1 , is_nested);
467
- handle_repetitions (0 , -1 );
468
- } else if (*pos == ' +' ) {
469
- pos = parse_space (pos + 1 , is_nested);
470
- handle_repetitions (1 , -1 );
471
- } else if (*pos == ' ?' ) {
472
- pos = parse_space (pos + 1 , is_nested);
473
- handle_repetitions (0 , 1 );
474
- } else if (*pos == ' {' ) {
489
+ } else if (*pos == ' ,' ) {
475
490
pos = parse_space (pos + 1 , is_nested);
476
491
477
- if (!is_digit_char (*pos)) {
478
- throw std::runtime_error (std::string (" expecting an int at " ) + pos);
492
+ if (is_digit_char (*pos)) {
493
+ const char * int_end = parse_int (pos);
494
+ max_times = std::stoul (std::string (pos, int_end - pos));
495
+ pos = parse_space (int_end, is_nested);
479
496
}
480
- const char * int_end = parse_int (pos);
481
- int min_times = std::stoul (std::string (pos, int_end - pos));
482
- pos = parse_space (int_end, is_nested);
483
-
484
- int max_times = -1 ;
485
-
486
- if (*pos == ' }' ) {
487
- max_times = min_times;
488
- pos = parse_space (pos + 1 , is_nested);
489
- } else if (*pos == ' ,' ) {
490
- pos = parse_space (pos + 1 , is_nested);
491
-
492
- if (is_digit_char (*pos)) {
493
- const char * int_end = parse_int (pos);
494
- max_times = std::stoul (std::string (pos, int_end - pos));
495
- pos = parse_space (int_end, is_nested);
496
- }
497
497
498
- if (*pos != ' }' ) {
499
- throw std::runtime_error (std::string (" expecting '}' at " ) + pos);
500
- }
501
- pos = parse_space (pos + 1 , is_nested);
502
- } else {
503
- throw std::runtime_error (std::string (" expecting ',' at " ) + pos);
498
+ if (*pos != ' }' ) {
499
+ throw std::runtime_error (std::string (" expecting '}' at " ) + pos);
504
500
}
505
- handle_repetitions (min_times, max_times );
501
+ pos = parse_space (pos + 1 , is_nested );
506
502
} else {
507
- break ;
503
+ throw std::runtime_error ( std::string ( " expecting ',' at " ) + pos) ;
508
504
}
505
+ handle_repetitions (min_times, max_times);
506
+ } else {
507
+ break ;
509
508
}
510
- return pos;
511
509
}
510
+ return pos;
511
+ }
512
512
513
513
const char * llama_grammar_parser::parse_rule (const char * src) {
514
- const char * name_end = parse_name (src);
515
- const char * pos = parse_space (name_end, false );
516
- size_t name_len = name_end - src;
517
- uint32_t rule_id = get_symbol_id (src, name_len);
518
- const std::string name (src, name_len);
519
-
520
- if (!(pos[0 ] == ' :' && pos[1 ] == ' :' && pos[2 ] == ' =' )) {
521
- throw std::runtime_error (std::string (" expecting ::= at " ) + pos);
522
- }
523
- pos = parse_space (pos + 3 , true );
514
+ const char * name_end = parse_name (src);
515
+ const char * pos = parse_space (name_end, false );
516
+ size_t name_len = name_end - src;
517
+ uint32_t rule_id = get_symbol_id (src, name_len);
518
+ const std::string name (src, name_len);
519
+
520
+ if (!(pos[0 ] == ' :' && pos[1 ] == ' :' && pos[2 ] == ' =' )) {
521
+ throw std::runtime_error (std::string (" expecting ::= at " ) + pos);
522
+ }
523
+ pos = parse_space (pos + 3 , true );
524
524
525
- pos = parse_alternates (pos, name, rule_id, false );
525
+ pos = parse_alternates (pos, name, rule_id, false );
526
526
527
- if (*pos == ' \r ' ) {
528
- pos += pos[1 ] == ' \n ' ? 2 : 1 ;
529
- } else if (*pos == ' \n ' ) {
530
- pos++;
531
- } else if (*pos) {
532
- throw std::runtime_error (std::string (" expecting newline or end at " ) + pos);
533
- }
534
- return parse_space (pos, true );
527
+ if (*pos == ' \r ' ) {
528
+ pos += pos[1 ] == ' \n ' ? 2 : 1 ;
529
+ } else if (*pos == ' \n ' ) {
530
+ pos++;
531
+ } else if (*pos) {
532
+ throw std::runtime_error (std::string (" expecting newline or end at " ) + pos);
535
533
}
534
+ return parse_space (pos, true );
535
+ }
536
536
537
537
bool llama_grammar_parser::parse (const char * src) {
538
538
try {
0 commit comments