Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Net MEVD: Sqlite escape table and column names, treat empty batches as NOP #11252

Open
wants to merge 6 commits into
base: feature-vector-data-preb2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,24 @@ internal static SqlCommand CreateTable(
sb.AppendTableName(schema, tableName);
sb.AppendLine(" (");
string keyColumnName = GetColumnName(keyProperty);
sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty));
sb.AppendEscaped(keyColumnName);
sb.AppendFormat(" {0} NOT NULL,", Map(keyProperty));
sb.AppendLine();
for (int i = 0; i < dataProperties.Count; i++)
{
sb.AppendFormat("[{0}] {1},", GetColumnName(dataProperties[i]), Map(dataProperties[i]));
sb.AppendEscaped(GetColumnName(dataProperties[i]));
sb.AppendFormat(" {0},", Map(dataProperties[i]));
sb.AppendLine();
}
for (int i = 0; i < vectorProperties.Count; i++)
{
sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions);
sb.AppendEscaped(GetColumnName(vectorProperties[i]));
sb.AppendFormat(" VECTOR({0}),", vectorProperties[i].Dimensions);
sb.AppendLine();
}
sb.AppendFormat("PRIMARY KEY ([{0}])", keyColumnName);
sb.AppendLine();
sb.Append("PRIMARY KEY (");
sb.AppendEscaped(keyColumnName);
sb.AppendLine(")");
sb.AppendLine(");"); // end the table definition

foreach (var dataProperty in dataProperties)
Expand All @@ -57,8 +61,9 @@ internal static SqlCommand CreateTable(
sb.AppendFormat("CREATE INDEX ");
sb.AppendIndexName(tableName, GetColumnName(dataProperty));
sb.AppendFormat(" ON ").AppendTableName(schema, tableName);
sb.AppendFormat("([{0}]);", GetColumnName(dataProperty));
sb.AppendLine();
sb.Append('(');
sb.AppendEscaped(GetColumnName(dataProperty));
sb.AppendLine(");");
}
}

Expand Down Expand Up @@ -140,14 +145,18 @@ internal static SqlCommand MergeIntoSingle(
sb.Append(") AS s (");
sb.AppendColumnNames(properties);
sb.AppendLine(")");
sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine();
sb.Append("ON (t.").AppendEscaped(GetColumnName(keyProperty));
sb.Append(" = s.").AppendEscaped(GetColumnName(keyProperty));
sb.AppendLine(")");
sb.AppendLine("WHEN MATCHED THEN");
sb.Append("UPDATE SET ");
foreach (VectorStoreRecordProperty property in properties)
{
if (property != keyProperty) // don't update the key
{
sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property));
sb.Append("t.").AppendEscaped(GetColumnName(property));
sb.Append(" = s.").AppendEscaped(GetColumnName(property));
sb.Append(',');
}
}
--sb.Length; // remove the last comma
Expand All @@ -161,7 +170,7 @@ internal static SqlCommand MergeIntoSingle(
sb.Append("VALUES (");
sb.AppendColumnNames(properties, prefix: "s.");
sb.AppendLine(")");
sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty));
sb.Append("OUTPUT inserted.").AppendEscaped(GetColumnName(keyProperty)).Append(';');

command.CommandText = sb.ToString();
return command;
Expand Down Expand Up @@ -208,14 +217,18 @@ internal static bool MergeIntoMany(
sb.Append(") AS s ("); // s stands for source
sb.AppendColumnNames(properties);
sb.AppendLine(")");
sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine();
sb.Append("ON (t.").AppendEscaped(GetColumnName(keyProperty));
sb.Append(" = s.").AppendEscaped(GetColumnName(keyProperty));
sb.AppendLine(")");
sb.AppendLine("WHEN MATCHED THEN");
sb.Append("UPDATE SET ");
foreach (VectorStoreRecordProperty property in properties)
{
if (property != keyProperty) // don't update the key
{
sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property));
sb.Append("t.").AppendEscaped(GetColumnName(property));
sb.Append(" = s.").AppendEscaped(GetColumnName(property));
sb.Append(',');
}
}
--sb.Length; // remove the last comma
Expand All @@ -228,8 +241,8 @@ internal static bool MergeIntoMany(
sb.Append("VALUES (");
sb.AppendColumnNames(properties, prefix: "s.");
sb.AppendLine(")");
sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", GetColumnName(keyProperty));
sb.AppendLine();
sb.Append("OUTPUT inserted.").AppendEscaped(GetColumnName(keyProperty));
sb.AppendLine(" INTO @InsertedKeys (KeyColumn);");

// The SELECT statement returns the keys of the inserted rows.
sb.Append("SELECT KeyColumn FROM @InsertedKeys;");
Expand All @@ -248,7 +261,7 @@ internal static SqlCommand DeleteSingle(
StringBuilder sb = new(100);
sb.Append("DELETE FROM ");
sb.AppendTableName(schema, tableName);
sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty));
sb.Append(" WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" = ");
sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName);
command.AddParameter(keyProperty, keyParamName, key);

Expand All @@ -263,7 +276,7 @@ internal static bool DeleteMany<TKey>(
StringBuilder sb = new(100);
sb.Append("DELETE FROM ");
sb.AppendTableName(schema, tableName);
sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty));
sb.Append(" WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" IN (");
sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys);
sb.Append(')'); // close the IN clause

Expand Down Expand Up @@ -293,7 +306,7 @@ internal static SqlCommand SelectSingle(
sb.Append("FROM ");
sb.AppendTableName(schema, collectionName);
sb.AppendLine();
sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty));
sb.Append("WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" = ");
sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName);
command.AddParameter(keyProperty, keyParamName, key);

Expand All @@ -315,7 +328,7 @@ internal static bool SelectMany<TKey>(
sb.Append("FROM ");
sb.AppendTableName(schema, tableName);
sb.AppendLine();
sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty));
sb.Append("WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" IN (");
sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys);
sb.Append(')'); // close the IN clause

Expand Down Expand Up @@ -346,8 +359,9 @@ internal static SqlCommand SelectVector<TRecord>(
sb.AppendFormat("SELECT ");
sb.AppendColumnNames(properties, includeVectors: options.IncludeVectors);
sb.AppendLine(",");
sb.AppendFormat("VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]",
distanceMetric, GetColumnName(vectorProperty), vector.Length);
sb.AppendFormat("VECTOR_DISTANCE('{0}', ", distanceMetric);
sb.AppendEscaped(GetColumnName(vectorProperty));
sb.AppendFormat(", CAST(@vector AS VECTOR({0}))) AS [score]", vector.Length);
sb.AppendLine();
sb.Append("FROM ");
sb.AppendTableName(schema, tableName);
Expand Down Expand Up @@ -411,25 +425,29 @@ internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorS
return sb;
}

internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName)
internal static StringBuilder AppendEscaped(this StringBuilder sb, string value)
{
// If the column name contains a ], then escape it by doubling it.
// If the value contains a ], then escape it by doubling it.
// "Name with [brackets]" becomes [Name with [brackets]]].

sb.Append('[');
int index = sb.Length; // store the index, so we replace ] only for the appended part
sb.Append(value);
sb.Replace("]", "]]", index, value.Length); // replace the ] for value (escape it)
sb.Append(']');

return sb;
}

internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName)
{
if (!string.IsNullOrEmpty(schema))
{
sb.Append(schema);
sb.Replace("]", "]]", index, schema!.Length); // replace the ] for schema
sb.Append("].[");
index = sb.Length;
sb.AppendEscaped(schema!);
sb.Append('.');
}

sb.Append(tableName);
sb.Replace("]", "]]", index, tableName.Length);
sb.Append(']');
sb.AppendEscaped(tableName);

return sb;
}
Expand All @@ -452,7 +470,8 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb,
sb.Append(prefix);
}
// Use square brackets to escape column names.
sb.AppendFormat("[{0}],", GetColumnName(property));
sb.AppendEscaped(GetColumnName(property));
sb.Append(',');
any = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ internal abstract class SqliteWhereCondition(string operand, List<object> values
public abstract string BuildQuery(List<string> parameterNames);

protected string GetOperand() => !string.IsNullOrWhiteSpace(this.TableName) ?
$"{this.TableName}.{this.Operand}" :
this.Operand;
$"[{this.TableName}].[{this.Operand}]" :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In SQLite, the standard quoting mechanism is double-quoted (the square brackets is more of a SQL Server thing). In fact I'm surprised this works!

Also, don't we need to escape TableName and Operand, in case they contain special characters?

I'd maybe suggest having a single method somewhere (RenderSqlIdentifier?) which does the quoting and escaping, and used from everywhere (for column, table, index names...).

$"[{this.Operand}]";
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str
{
var builder = new StringBuilder();

builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} (");
builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] (");

builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition)));
builder.Append(");");
Expand All @@ -48,7 +48,8 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str
{
if (column.HasIndex)
{
builder.AppendLine($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName}_{column.Name}_index ON {tableName}({column.Name});");
builder.AppendLine();
builder.Append($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}_{column.Name}_index] ON [{tableName}]([{column.Name}]);");
}
}

Expand All @@ -68,7 +69,7 @@ public static DbCommand BuildCreateVirtualTableCommand(
{
var builder = new StringBuilder();

builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} USING {extensionName}(");
builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] USING {extensionName}(");

builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition)));
builder.Append(");");
Expand Down Expand Up @@ -113,7 +114,7 @@ public static DbCommand BuildInsertCommand(
records[recordIndex],
recordIndex);

builder.AppendLine($"INSERT{replacePlaceholder} INTO {tableName} ({string.Join(", ", columns)})");
builder.AppendLine($"INSERT{replacePlaceholder} INTO [{tableName}] ({string.Join(", ", columns)})");
builder.AppendLine($"VALUES ({string.Join(", ", parameters)})");
builder.AppendLine($"RETURNING {rowIdentifier};");

Expand All @@ -139,8 +140,14 @@ public static DbCommand BuildSelectCommand(

var (command, whereClause) = GetCommandWithWhereClause(connection, conditions);

builder.AppendLine($"SELECT {string.Join(", ", columnNames)}");
builder.AppendLine($"FROM {tableName}");
builder.Append("SELECT ");
foreach (var columnName in columnNames)
{
builder.AppendFormat("[{0}],", columnName);
}
builder.Length--; // Remove the last comma
builder.AppendLine();
builder.AppendLine($"FROM [{tableName}]");

AppendWhereClauseIfExists(builder, whereClause);
AppendOrderByIfExists(builder, orderByPropertyName);
Expand All @@ -166,15 +173,15 @@ public static DbCommand BuildSelectLeftJoinCommand(

List<string> propertyNames =
[
.. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"),
.. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"),
.. leftTablePropertyNames.Select(property => $"[{leftTable}].[{property}]"),
.. rightTablePropertyNames.Select(property => $"[{rightTable}].[{property}]"),
];

var (command, whereClause) = GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters);

builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}");
builder.AppendLine($"FROM {leftTable} ");
builder.AppendLine($"LEFT JOIN {rightTable} ON {leftTable}.{joinColumnName} = {rightTable}.{joinColumnName}");
builder.AppendLine($"FROM [{leftTable}] ");
builder.AppendLine($"LEFT JOIN [{rightTable}] ON [{leftTable}].[{joinColumnName}] = [{rightTable}].[{joinColumnName}]");

AppendWhereClauseIfExists(builder, whereClause);
AppendOrderByIfExists(builder, orderByPropertyName);
Expand Down Expand Up @@ -301,7 +308,7 @@ private static (List<string> Columns, List<string> ParameterNames, List<object?>
{
if (record.TryGetValue(propertyName, out var value))
{
columns.Add(propertyName);
columns.Add($"[{propertyName}]");
parameterNames.Add(GetParameterName(propertyName, index));
parameterValues.Add(value ?? DBNull.Value);
}
Expand Down
Loading
Loading