Skip to content

Commit e15e335

Browse files
Config option to allow custom parameter syntax (#78)
Co-authored-by: Matthew Peveler <[email protected]>
1 parent c17b973 commit e15e335

File tree

7 files changed

+550
-61
lines changed

7 files changed

+550
-61
lines changed

src/defines.ts

+10
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,20 @@ export type StatementType =
7676

7777
export type ExecutionType = 'LISTING' | 'MODIFICATION' | 'INFORMATION' | 'ANON_BLOCK' | 'UNKNOWN';
7878

79+
/**
 * Selects which parameter placeholder syntaxes the tokenizer recognises.
 *
 * Every field is optional; a syntax whose field is omitted (or empty) is not
 * recognised as a parameter.
 */
export interface ParamTypes {
  /** When true, a bare `?` is recognised as a positional parameter. */
  positional?: boolean;
  /** Prefix characters that, directly followed by digits, form a numbered parameter (e.g. `$1`). */
  numbered?: ('?' | ':' | '$')[];
  /** Prefix characters that, followed by an alphanumeric name, form a named parameter (e.g. `:name`). */
  named?: (':' | '@' | '$')[];
  /** Prefix characters that, followed by a quoted identifier, form a quoted parameter (e.g. `:"my name"`). */
  quoted?: (':' | '@' | '$')[];
  /**
   * Regular-expression sources identifying custom parameters. Each source is
   * compiled as `^(?:source)` with the `u` flag and tested against the input
   * starting at the current token.
   */
  custom?: string[];
}
87+
7988
/** Options accepted by `identify`. */
export interface IdentifyOptions {
  // NOTE(review): presumably toggles strict parsing; how it is consumed is not shown here — confirm.
  strict?: boolean;
  /** SQL dialect of the query. */
  dialect?: Dialect;
  /** Forwarded to the parser so that statements can report the tables they reference. */
  identifyTables?: boolean;
  /**
   * Parameter syntaxes to recognise. When omitted, the defaults for the
   * selected dialect are used. When provided, it replaces those defaults
   * entirely and the psql-specific sorting of parameters is not applied.
   */
  paramTypes?: ParamTypes;
}
8494

8595
export interface IdentifyResult {

src/index.ts

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { parse, EXECUTION_TYPES } from './parser';
1+
import { parse, EXECUTION_TYPES, defaultParamTypesFor } from './parser';
22
import { DIALECTS } from './defines';
33
import type { ExecutionType, IdentifyOptions, IdentifyResult, StatementType } from './defines';
44

@@ -21,7 +21,11 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
2121
throw new Error(`Unknown dialect. Allowed values: ${DIALECTS.join(', ')}`);
2222
}
2323

24-
const result = parse(query, isStrict, dialect, options.identifyTables);
24+
// Default parameter types for each dialect
25+
const paramTypes = options.paramTypes || defaultParamTypesFor(dialect);
26+
27+
const result = parse(query, isStrict, dialect, options.identifyTables, paramTypes);
28+
const sort = dialect === 'psql' && !options.paramTypes;
2529

2630
return result.body.map((statement) => {
2731
const result: IdentifyResult = {
@@ -31,7 +35,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
3135
type: statement.type,
3236
executionType: statement.executionType,
3337
// we want to sort the postgres params: $1 $2 $3, regardless of the order they appear
34-
parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters,
38+
parameters: sort ? statement.parameters.sort() : statement.parameters,
3539
tables: statement.tables || [],
3640
};
3741
return result;

src/parser.ts

+32-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import type {
99
Step,
1010
ParseResult,
1111
ConcreteStatement,
12+
ParamTypes,
1213
} from './defines';
1314

1415
interface StatementParser {
@@ -144,6 +145,7 @@ export function parse(
144145
isStrict = true,
145146
dialect: Dialect = 'generic',
146147
identifyTables = false,
148+
paramTypes?: ParamTypes,
147149
): ParseResult {
148150
const topLevelState = initState({ input });
149151
const topLevelStatement: ParseResult = {
@@ -174,7 +176,7 @@ export function parse(
174176

175177
while (prevState.position < topLevelState.end) {
176178
const tokenState = initState({ prevState });
177-
const token = scanToken(tokenState, dialect);
179+
const token = scanToken(tokenState, dialect, paramTypes);
178180
const nextToken = nextNonWhitespaceToken(tokenState, dialect);
179181

180182
if (!statementParser) {
@@ -1013,3 +1015,32 @@ function stateMachineStatementParser(
10131015
},
10141016
};
10151017
}
1018+
1019+
/**
 * Returns the parameter syntaxes recognised by default for a dialect.
 *
 * A fresh object is returned on every call, so callers may mutate the result
 * without affecting later calls.
 *
 * @param dialect - The SQL dialect being parsed.
 * @returns The default `ParamTypes` for that dialect; dialects without a
 *   dedicated case only recognise the positional `?`.
 */
export function defaultParamTypesFor(dialect: Dialect): ParamTypes {
  switch (dialect) {
    case 'psql':
      // $1, $2, ...
      return {
        numbered: ['$'],
      };
    case 'mssql':
      // :name
      return {
        named: [':'],
      };
    case 'bigquery':
      // ?, @name, @`quoted name`
      return {
        positional: true,
        named: ['@'],
        quoted: ['@'],
      };
    case 'sqlite':
      // ?, ?1, :name, @name
      return {
        positional: true,
        numbered: ['?'],
        named: [':', '@'],
      };
    default:
      // ?
      return {
        positional: true,
      };
  }
}

src/tokenizer.ts

+106-42
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Tokenizer
33
*/
44

5-
import type { Token, State, Dialect } from './defines';
5+
import type { Token, State, Dialect, ParamTypes } from './defines';
66

77
type Char = string | null;
88

@@ -76,7 +76,11 @@ const ENDTOKENS: Record<string, Char> = {
7676
'[': ']',
7777
};
7878

79-
export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
79+
export function scanToken(
80+
state: State,
81+
dialect: Dialect = 'generic',
82+
paramTypes: ParamTypes = { positional: true },
83+
): Token {
8084
const ch = read(state);
8185

8286
if (isWhitespace(ch)) {
@@ -95,8 +99,8 @@ export function scanToken(state: State, dialect: Dialect = 'generic'): Token {
9599
return scanString(state, ENDTOKENS[ch]);
96100
}
97101

98-
if (isParameter(ch, state, dialect)) {
99-
return scanParameter(state, dialect);
102+
if (isParameter(ch, state, paramTypes)) {
103+
return scanParameter(state, dialect, paramTypes);
100104
}
101105

102106
if (isDollarQuotedString(state)) {
@@ -253,52 +257,92 @@ function scanString(state: State, endToken: Char): Token {
253257
};
254258
}
255259

256-
function scanParameter(state: State, dialect: Dialect): Token {
257-
if (['mysql', 'generic', 'sqlite'].includes(dialect)) {
258-
return {
259-
type: 'parameter',
260-
value: state.input.slice(state.start, state.position + 1),
261-
start: state.start,
262-
end: state.start,
263-
};
264-
}
260+
/**
 * Finds the custom parameter that starts at the current token.
 *
 * Each entry of `paramTypes.custom` is compiled as `^(?:source)` with the `u`
 * flag and run against the input from `state.start` onwards. Patterns are
 * tried in order and the search stops at the first one that matches.
 *
 * @param state - Tokenizer state; only `input` and `start` are read.
 * @param paramTypes - Parameter configuration holding the custom patterns.
 * @returns The matched text of the first matching pattern, or `null` when no
 *   pattern matches or no custom patterns are configured.
 */
function getCustomParam(state: State, paramTypes: ParamTypes): string | null {
  const patterns = paramTypes?.custom;
  if (!patterns) {
    return null;
  }

  // Slice once; every pattern is tested against the same remaining input.
  const remaining = state.input.slice(state.start);

  for (const pattern of patterns) {
    const match = new RegExp(`^(?:${pattern})`, 'u').exec(remaining);
    if (match) {
      return match[0];
    }
  }

  return null;
}
268270

269-
do {
270-
nextChar = read(state);
271-
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));
271+
/**
 * Scans the token that starts at `state.start` after `isParameter` flagged it
 * as a possible parameter.
 *
 * The configured syntaxes are tried in this order: numbered, named, quoted,
 * custom. If none of them matches, the single current character is only
 * accepted as a parameter when it is a `?` and positional parameters are
 * enabled; otherwise an `unknown` token is returned.
 *
 * @param state - Tokenizer state; advanced past the characters of the token.
 * @param dialect - Dialect, used to decide what counts as a quoted identifier.
 * @param paramTypes - Parameter syntaxes to recognise.
 * @returns A `parameter` token, or an `unknown` token when nothing matched.
 */
function scanParameter(state: State, dialect: Dialect, paramTypes: ParamTypes): Token {
  const curCh = state.input[state.start];
  const nextChar = peek(state);
  let matched = false;

  if (paramTypes.numbered?.some((type) => type === curCh)) {
    // Take the text after the prefix up to the first non-word character (or
    // the end when there is none) and require it to consist of digits only,
    // so that e.g. `$1abc` is not read as the numbered parameter `$1`.
    const endIndex = state.input.slice(state.start + 1).search(/\W/);
    const maybeNumbers = state.input.slice(
      state.start + 1,
      endIndex > 0 ? state.start + endIndex + 1 : state.end + 1,
    );
    if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) {
      let consumed: Char = null;
      do {
        consumed = read(state);
      } while (consumed !== null && !isNaN(Number(consumed)) && !isWhitespace(consumed));

      // The last read() returned a character that is not part of the number;
      // unread() it so it is left for the next token.
      if (consumed !== null) unread(state);
      matched = true;
    }
  }

  if (!matched && paramTypes.named?.some((type) => type === curCh)) {
    // A prefix followed by a quote character is left to the quoted branch.
    if (!isQuotedIdentifier(nextChar, dialect)) {
      while (isAlphaNumeric(peek(state))) read(state);
      matched = true;
    }
  }

  if (!matched && paramTypes.quoted?.some((type) => type === curCh)) {
    if (isQuotedIdentifier(nextChar, dialect)) {
      const quoteChar = read(state) as string;
      // Consume the quoted name: alphanumerics and spaces, up to the end quote.
      // NOTE(review): the loop also stops at any other character (e.g. `-`),
      // which the read() below then consumes as if it were the end quote —
      // confirm whether quoted names may contain such characters.
      while (
        (isAlphaNumeric(peek(state)) || peek(state) === ' ') &&
        peek(state) !== ENDTOKENS[quoteChar]
      ) {
        read(state);
      }

      // read the end quote
      read(state);

      matched = true;
    }
  }

  if (!matched && paramTypes.custom && paramTypes.custom.length) {
    const custom = getCustomParam(state, paramTypes);

    if (custom) {
      // NOTE(review): the first character of the token appears to be consumed
      // already when this function is entered, so skipping `custom.length`
      // may overshoot by one — confirm against read()'s skip semantics.
      read(state, custom.length);
      matched = true;
    }
  }

  const value = state.input.slice(state.start, state.position + 1);

  // Without a match only a bare `?` can still be a parameter, and only when
  // positional parameters are enabled. Any other leftover prefix character
  // (e.g. the `$` of `$1abc`) must not be reported as a parameter.
  const isPositional = Boolean(paramTypes.positional) && curCh === '?';
  const type = matched || isPositional ? 'parameter' : 'unknown';

  return {
    type,
    value,
    start: state.start,
    end: state.start + value.length - 1,
  };
}
304348

@@ -413,18 +457,38 @@ function isString(ch: Char, dialect: Dialect): boolean {
413457
return stringStart.includes(ch);
414458
}
415459

416-
function isParameter(ch: Char, state: State, dialect: Dialect): boolean {
417-
let pStart = '?'; // ansi standard - sqlite, mysql
418-
if (dialect === 'psql') {
419-
pStart = '$';
420-
const nextChar = peek(state);
421-
if (nextChar === null || isNaN(Number(nextChar))) {
422-
return false;
460+
/**
 * Tells whether any custom parameter pattern matches at the current token.
 *
 * Each pattern is compiled as `^(?:source)` with the `u` flag and run against
 * the input from `state.start` onwards. A pattern that only matches the empty
 * string does not count: an empty match cannot form a token and would
 * otherwise make every character look like the start of a parameter.
 *
 * @param state - Tokenizer state; only `input` and `start` are read.
 * @param customParamType - Regular-expression sources of the custom parameters.
 * @returns True when at least one pattern matches a non-empty prefix.
 */
function isCustomParam(state: State, customParamType: NonNullable<ParamTypes['custom']>): boolean {
  // Slice once; every pattern is tested against the same remaining input.
  const remaining = state.input.slice(state.start);

  return customParamType.some((regex) => {
    const match = new RegExp(`^(?:${regex})`, 'u').exec(remaining);
    return match !== null && match[0].length > 0;
  });
}
466+
467+
/**
 * Decides whether the character just read can start a parameter under the
 * given configuration.
 *
 * @param ch - The current character, or `null` at the end of input.
 * @param state - Tokenizer state, used to look at the following character and
 *   to test custom patterns.
 * @param paramTypes - Parameter syntaxes to recognise.
 * @returns True when `ch` is a positional `?`, a numbered prefix followed by
 *   a digit, a named or quoted prefix, or the start of a custom parameter.
 */
function isParameter(ch: Char, state: State, paramTypes: ParamTypes): boolean {
  if (!ch) {
    return false;
  }

  if (paramTypes.positional && ch === '?') return true;

  if (paramTypes.numbered?.some((type) => ch === type)) {
    const nextChar = peek(state);
    // Test for a digit explicitly: Number(' ') is 0, so a NaN check would
    // also accept a prefix that is merely followed by whitespace.
    if (nextChar !== null && /^\d$/.test(nextChar)) {
      return true;
    }
  }

  if (
    paramTypes.named?.some((type) => type === ch) ||
    paramTypes.quoted?.some((type) => type === ch)
  ) {
    return true;
  }

  if (paramTypes.custom?.length && isCustomParam(state, paramTypes.custom)) {
    return true;
  }

  return false;
}
429493

430494
function isDollarQuotedString(state: State): boolean {

test/index.spec.ts

+44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { Dialect, getExecutionType, identify } from '../src/index';
22
import { expect } from 'chai';
3+
import { ParamTypes } from '../src/defines';
34

45
describe('identify', () => {
56
it('should throw error for invalid dialect', () => {
@@ -22,6 +23,49 @@ describe('identify', () => {
2223
]);
2324
});
2425

26+
it('should identify custom parameters', () => {
27+
const paramTypes: ParamTypes = {
28+
positional: true,
29+
numbered: ['$'],
30+
named: [':'],
31+
quoted: [':'],
32+
custom: ['\\{[a-zA-Z0-9_]+\\}'],
33+
};
34+
const query = `SELECT * FROM foo WHERE bar = ? AND baz = $1 AND fizz = :fizzz AND buzz = :"buzz buzz" AND foo2 = {fooo}`;
35+
36+
expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
37+
{
38+
start: 0,
39+
end: 104,
40+
text: query,
41+
type: 'SELECT',
42+
executionType: 'LISTING',
43+
parameters: ['?', '$1', ':fizzz', ':"buzz buzz"', '{fooo}'],
44+
tables: [],
45+
},
46+
]);
47+
});
48+
49+
it('custom params should override defaults for dialect', () => {
50+
const paramTypes: ParamTypes = {
51+
positional: true,
52+
};
53+
54+
const query = 'SELECT * FROM foo WHERE bar = $1 AND bar = :named AND fizz = :`quoted`';
55+
56+
expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
57+
{
58+
start: 0,
59+
end: 69,
60+
text: query,
61+
type: 'SELECT',
62+
executionType: 'LISTING',
63+
parameters: [],
64+
tables: [],
65+
},
66+
]);
67+
});
68+
2569
it('should identify tables in simple for basic cases', () => {
2670
expect(
2771
identify('SELECT * FROM foo JOIN bar ON foo.id = bar.id', { identifyTables: true }),

0 commit comments

Comments
 (0)