Skip to content

.Net: Net MEVD: Sqlite escape table and column names, treat empty batches as NOP #11252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: feature-vector-data-preb2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ internal abstract class SqliteWhereCondition(string operand, List<object> values
public abstract string BuildQuery(List<string> parameterNames);

protected string GetOperand() => !string.IsNullOrWhiteSpace(this.TableName) ?
$"{this.TableName}.{this.Operand}" :
this.Operand;
$"[{this.TableName}].[{this.Operand}]" :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In SQLite, the standard quoting mechanism for identifiers is double quotes (square brackets are more of a SQL Server thing). In fact, I'm surprised this works!

Also, don't we need to escape TableName and Operand, in case they contain special characters?

I'd maybe suggest having a single method somewhere (RenderSqlIdentifier?) that does the quoting and escaping and is used everywhere (for column, table, and index names...).

$"[{this.Operand}]";
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str
{
var builder = new StringBuilder();

builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} (");
builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] (");

builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition)));
builder.Append(");");
Expand All @@ -48,7 +48,8 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str
{
if (column.HasIndex)
{
builder.AppendLine($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName}_{column.Name}_index ON {tableName}({column.Name});");
builder.AppendLine();
builder.Append($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}_{column.Name}_index] ON [{tableName}]([{column.Name}]);");
}
}

Expand All @@ -68,7 +69,7 @@ public static DbCommand BuildCreateVirtualTableCommand(
{
var builder = new StringBuilder();

builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} USING {extensionName}(");
builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] USING {extensionName}(");

builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition)));
builder.Append(");");
Expand Down Expand Up @@ -113,7 +114,7 @@ public static DbCommand BuildInsertCommand(
records[recordIndex],
recordIndex);

builder.AppendLine($"INSERT{replacePlaceholder} INTO {tableName} ({string.Join(", ", columns)})");
builder.AppendLine($"INSERT{replacePlaceholder} INTO [{tableName}] ({string.Join(", ", columns)})");
builder.AppendLine($"VALUES ({string.Join(", ", parameters)})");
builder.AppendLine($"RETURNING {rowIdentifier};");

Expand All @@ -139,8 +140,14 @@ public static DbCommand BuildSelectCommand(

var (command, whereClause) = GetCommandWithWhereClause(connection, conditions);

builder.AppendLine($"SELECT {string.Join(", ", columnNames)}");
builder.AppendLine($"FROM {tableName}");
builder.Append("SELECT ");
foreach (var columnName in columnNames)
{
builder.AppendFormat("[{0}],", columnName);
}
builder.Length--; // Remove the last comma
builder.AppendLine();
builder.AppendLine($"FROM [{tableName}]");

AppendWhereClauseIfExists(builder, whereClause);
AppendOrderByIfExists(builder, orderByPropertyName);
Expand All @@ -166,15 +173,15 @@ public static DbCommand BuildSelectLeftJoinCommand(

List<string> propertyNames =
[
.. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"),
.. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"),
.. leftTablePropertyNames.Select(property => $"[{leftTable}].[{property}]"),
.. rightTablePropertyNames.Select(property => $"[{rightTable}].[{property}]"),
];

var (command, whereClause) = GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters);

builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}");
builder.AppendLine($"FROM {leftTable} ");
builder.AppendLine($"LEFT JOIN {rightTable} ON {leftTable}.{joinColumnName} = {rightTable}.{joinColumnName}");
builder.AppendLine($"FROM [{leftTable}] ");
builder.AppendLine($"LEFT JOIN [{rightTable}] ON [{leftTable}].[{joinColumnName}] = [{rightTable}].[{joinColumnName}]");

AppendWhereClauseIfExists(builder, whereClause);
AppendOrderByIfExists(builder, orderByPropertyName);
Expand Down Expand Up @@ -301,7 +308,7 @@ private static (List<string> Columns, List<string> ParameterNames, List<object?>
{
if (record.TryGetValue(propertyName, out var value))
{
columns.Add(propertyName);
columns.Add($"[{propertyName}]");
parameterNames.Add(GetParameterName(propertyName, index));
parameterValues.Add(value ?? DBNull.Value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,14 @@ public virtual Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>
/// <inheritdoc />
public async IAsyncEnumerable<TRecord> GetBatchAsync(IEnumerable<ulong> keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var keysList = GetKeysAsListOfObjects(keys);
if (keysList.Count == 0)
{
yield break;
}

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false))
await foreach (var record in this.InternalGetBatchAsync(connection, keysList, options, cancellationToken).ConfigureAwait(false))
{
yield return record;
}
Expand All @@ -257,6 +263,8 @@ public async Task<ulong> UpsertAsync(TRecord record, CancellationToken cancellat
/// <inheritdoc />
public async IAsyncEnumerable<ulong> UpsertBatchAsync(IEnumerable<TRecord> records, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Verify.NotNull(records);

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await foreach (var record in this.InternalUpsertBatchAsync<ulong>(connection, records, cancellationToken)
.ConfigureAwait(false))
Expand All @@ -275,8 +283,14 @@ public async Task DeleteAsync(ulong key, CancellationToken cancellationToken = d
/// <inheritdoc />
public async Task DeleteBatchAsync(IEnumerable<ulong> keys, CancellationToken cancellationToken = default)
{
var keysList = GetKeysAsListOfObjects(keys);
if (keysList.Count == 0)
{
return;
}

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false);
await this.InternalDeleteBatchAsync(connection, keysList, cancellationToken).ConfigureAwait(false);
}

#endregion
Expand All @@ -293,8 +307,14 @@ public async Task DeleteBatchAsync(IEnumerable<ulong> keys, CancellationToken ca
/// <inheritdoc />
public async IAsyncEnumerable<TRecord> GetBatchAsync(IEnumerable<string> keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var keysList = GetKeysAsListOfObjects(keys);
if (keysList.Count == 0)
{
yield break;
}

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false))
await foreach (var record in this.InternalGetBatchAsync(connection, keysList, options, cancellationToken).ConfigureAwait(false))
{
yield return record;
}
Expand All @@ -313,6 +333,8 @@ async IAsyncEnumerable<string> IVectorStoreRecordCollection<string, TRecord>.Ups
IEnumerable<TRecord> records,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
Verify.NotNull(records);

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await foreach (var record in this.InternalUpsertBatchAsync<string>(connection, records, cancellationToken)
.ConfigureAwait(false))
Expand All @@ -324,6 +346,8 @@ async IAsyncEnumerable<string> IVectorStoreRecordCollection<string, TRecord>.Ups
/// <inheritdoc />
public async Task DeleteAsync(string key, CancellationToken cancellationToken = default)
{
Verify.NotNull(key);

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await this.InternalDeleteAsync(connection, key, cancellationToken)
.ConfigureAwait(false);
Expand All @@ -332,8 +356,14 @@ await this.InternalDeleteAsync(connection, key, cancellationToken)
/// <inheritdoc />
public async Task DeleteBatchAsync(IEnumerable<string> keys, CancellationToken cancellationToken = default)
{
var keysList = GetKeysAsListOfObjects(keys);
if (keysList.Count == 0)
{
return;
}

using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false);
await this.InternalDeleteBatchAsync(connection, keysList, cancellationToken).ConfigureAwait(false);
}

#endregion
Expand Down Expand Up @@ -480,18 +510,12 @@ private Task<int> DropTableAsync(SqliteConnection connection, string tableName,
.ConfigureAwait(false);
}

private IAsyncEnumerable<TRecord> InternalGetBatchAsync<TKey>(
private IAsyncEnumerable<TRecord> InternalGetBatchAsync(
SqliteConnection connection,
IEnumerable<TKey> keys,
List<object> keysList,
GetRecordOptions? options,
CancellationToken cancellationToken)
{
Verify.NotNull(keys);

var keysList = keys.Cast<object>().ToList();

Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero.");

var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keysList)
{
TableName = this._dataTableName
Expand Down Expand Up @@ -583,6 +607,11 @@ private IAsyncEnumerable<TKey> InternalUpsertBatchAsync<TKey>(SqliteConnection c
OperationName,
() => this._mapper.MapFromDataToStorageModel(record))).ToList();

if (storageModels.Count == 0)
{
return AsyncEnumerable.Empty<TKey>();
}

var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList();

var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keys);
Expand Down Expand Up @@ -643,26 +672,18 @@ private async IAsyncEnumerable<TKey> InternalUpsertBatchAsync<TKey>(
}
}

private Task InternalDeleteAsync<TKey>(SqliteConnection connection, TKey key, CancellationToken cancellationToken)
private Task InternalDeleteAsync<TKey>(SqliteConnection connection, TKey key, CancellationToken cancellationToken) where TKey : notnull
{
Verify.NotNull(key);

var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key);

return this.InternalDeleteBatchAsync(connection, condition, cancellationToken);
}

private Task InternalDeleteBatchAsync<TKey>(SqliteConnection connection, IEnumerable<TKey> keys, CancellationToken cancellationToken)
private Task InternalDeleteBatchAsync(SqliteConnection connection, List<object> keys, CancellationToken cancellationToken)
{
Verify.NotNull(keys);

var keysList = keys.Cast<object>().ToList();

Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero.");

var condition = new SqliteWhereInCondition(
this._propertyReader.KeyPropertyStoragePropertyName,
keysList);
keys);

return this.InternalDeleteBatchAsync(connection, condition, cancellationToken);
}
Expand Down Expand Up @@ -811,4 +832,11 @@ private static string GetVectorTableName(
}

#endregion

/// <summary>
/// Checks that <paramref name="keys"/> is not null and copies the keys into a new list of <see cref="object"/>.
/// </summary>
/// <typeparam name="TKey">The element type of the incoming key sequence.</typeparam>
/// <param name="keys">The keys to convert; validated with <c>Verify.NotNull</c> before being enumerated.</param>
/// <returns>A new list containing every key cast to <see cref="object"/>, in enumeration order.</returns>
private static List<object> GetKeysAsListOfObjects<TKey>(IEnumerable<TKey> keys)
{
    Verify.NotNull(keys);

    // The sequence is enumerated exactly once, when the list is constructed.
    IEnumerable<object> keysAsObjects = keys.Cast<object>();

    return new List<object>(keysAsObjects);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ public void SqliteWhereEqualsConditionWithoutParameterNamesThrowsException()
}

[Theory]
[InlineData(null, "Name = @Name0")]
[InlineData("", "Name = @Name0")]
[InlineData("TableName", "TableName.Name = @Name0")]
[InlineData(null, "[Name] = @Name0")]
[InlineData("", "[Name] = @Name0")]
[InlineData("TableName", "[TableName].[Name] = @Name0")]
public void SqliteWhereEqualsConditionBuildsValidQuery(string? tableName, string expectedQuery)
{
// Arrange
Expand All @@ -48,9 +48,9 @@ public void SqliteWhereInConditionWithoutParameterNamesThrowsException()
}

[Theory]
[InlineData(null, "Name IN (@Name0, @Name1)")]
[InlineData("", "Name IN (@Name0, @Name1)")]
[InlineData("TableName", "TableName.Name IN (@Name0, @Name1)")]
[InlineData(null, "[Name] IN (@Name0, @Name1)")]
[InlineData("", "[Name] IN (@Name0, @Name1)")]
[InlineData("TableName", "[TableName].[Name] IN (@Name0, @Name1)")]
public void SqliteWhereInConditionBuildsValidQuery(string? tableName, string expectedQuery)
{
// Arrange
Expand All @@ -74,9 +74,9 @@ public void SqliteWhereMatchConditionWithoutParameterNamesThrowsException()
}

[Theory]
[InlineData(null, "Name MATCH @Name0")]
[InlineData("", "Name MATCH @Name0")]
[InlineData("TableName", "TableName.Name MATCH @Name0")]
[InlineData(null, "[Name] MATCH @Name0")]
[InlineData("", "[Name] MATCH @Name0")]
[InlineData("TableName", "[TableName].[Name] MATCH @Name0")]
public void SqliteWhereMatchConditionBuildsValidQuery(string? tableName, string expectedQuery)
{
// Arrange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void ItBuildsInsertCommand(bool replaceIfExists)
// Assert
Assert.Equal(replaceIfExists, command.CommandText.Contains("OR REPLACE"));

Assert.Contains($"INTO {TableName} (Id, Name, Age, Address)", command.CommandText);
Assert.Contains($"INTO [{TableName}] ([Id], [Name], [Age], [Address])", command.CommandText);
Assert.Contains("VALUES (@Id0, @Name0, @Age0, @Address0)", command.CommandText);
Assert.Contains("VALUES (@Id1, @Name1, @Age1, @Address1)", command.CommandText);
Assert.Contains("RETURNING Id", command.CommandText);
Expand Down Expand Up @@ -184,11 +184,11 @@ public void ItBuildsSelectCommand(string? orderByPropertyName)
var command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectCommand(this._connection, TableName, columnNames, conditions, orderByPropertyName);

// Assert
Assert.Contains("SELECT Id, Name, Age, Address", command.CommandText);
Assert.Contains($"FROM {TableName}", command.CommandText);
Assert.Contains("SELECT [Id],[Name],[Age],[Address]", command.CommandText);
Assert.Contains($"FROM [{TableName}]", command.CommandText);

Assert.Contains("Name = @Name0", command.CommandText);
Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText);
Assert.Contains("[Name] = @Name0", command.CommandText);
Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText);

Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}"));

Expand Down Expand Up @@ -239,13 +239,13 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName)
orderByPropertyName);

// Assert
Assert.Contains("SELECT LeftTable.Id, LeftTable.Name, RightTable.Age, RightTable.Address", command.CommandText);
Assert.Contains("FROM LeftTable", command.CommandText);
Assert.Contains("SELECT [LeftTable].[Id], [LeftTable].[Name], [RightTable].[Age], [RightTable].[Address]", command.CommandText);
Assert.Contains("FROM [LeftTable]", command.CommandText);

Assert.Contains("LEFT JOIN RightTable ON LeftTable.Id = RightTable.Id", command.CommandText);
Assert.Contains("LEFT JOIN [RightTable] ON [LeftTable].[Id] = [RightTable].[Id]", command.CommandText);

Assert.Contains("Name = @Name0", command.CommandText);
Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText);
Assert.Contains("[Name] = @Name0", command.CommandText);
Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText);

Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}"));

Expand Down Expand Up @@ -280,8 +280,8 @@ public void ItBuildsDeleteCommand()
// Assert
Assert.Contains("DELETE FROM [TestTable]", command.CommandText);

Assert.Contains("Name = @Name0", command.CommandText);
Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText);
Assert.Contains("[Name] = @Name0", command.CommandText);
Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText);

Assert.Equal("@Name0", command.Parameters[0].ParameterName);
Assert.Equal("NameValue", command.Parameters[0].Value);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.

using SqliteIntegrationTests.Support;
using VectorDataSpecificationTests.CRUD;
using Xunit;

namespace SqliteIntegrationTests.CRUD;

/// <summary>
/// Binds the shared <see cref="BatchConformanceTests{TKey}"/> suite to the SQLite fixture using <see cref="string"/> keys.
/// The class declares no members of its own; it only forwards the fixture to the base class.
/// </summary>
/// <param name="fixture">The fixture instance passed on to the base conformance test class.</param>
public class SqliteBatchConformanceTests_string(SqliteSimpleModelFixture<string> fixture)
: BatchConformanceTests<string>(fixture), IClassFixture<SqliteSimpleModelFixture<string>>
{
}

/// <summary>
/// Binds the shared <see cref="BatchConformanceTests{TKey}"/> suite to the SQLite fixture using <see cref="ulong"/> keys.
/// The class declares no members of its own; it only forwards the fixture to the base class.
/// </summary>
/// <param name="fixture">The fixture instance passed on to the base conformance test class.</param>
public class SqliteBatchConformanceTests_ulong(SqliteSimpleModelFixture<ulong> fixture)
: BatchConformanceTests<ulong>(fixture), IClassFixture<SqliteSimpleModelFixture<ulong>>
{
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.

using SqliteIntegrationTests.Support;
using VectorDataSpecificationTests.CRUD;
using Xunit;

namespace SqliteIntegrationTests.CRUD;

/// <summary>
/// Binds the shared <see cref="RecordConformanceTests{TKey}"/> suite to the SQLite fixture using <see cref="string"/> keys.
/// The class declares no members of its own; it only forwards the fixture to the base class.
/// </summary>
/// <param name="fixture">The fixture instance passed on to the base conformance test class.</param>
public class SqliteRecordConformanceTests_string(SqliteSimpleModelFixture<string> fixture)
: RecordConformanceTests<string>(fixture), IClassFixture<SqliteSimpleModelFixture<string>>
{
}

/// <summary>
/// Binds the shared <see cref="RecordConformanceTests{TKey}"/> suite to the SQLite fixture using <see cref="ulong"/> keys,
/// replacing the null-key test because a <see cref="ulong"/> key has no null value.
/// </summary>
/// <param name="fixture">The fixture instance passed on to the base class and used by the overridden test.</param>
public class SqliteRecordConformanceTests_ulong(SqliteSimpleModelFixture<ulong> fixture)
: RecordConformanceTests<ulong>(fixture), IClassFixture<SqliteSimpleModelFixture<ulong>>
{
/// <summary>
/// Replaces the base test: instead of expecting an exception, it looks up <c>default(ulong)</c>
/// and asserts that the collection returns <see langword="null"/> for it.
/// </summary>
// NOTE(review): the null result presumably relies on no record with key 0 existing in the fixture data - confirm.
public override async Task GetAsyncThrowsArgumentNullExceptionForNullKey()
{
// default(ulong) is a valid key
var result = await fixture.Collection.GetAsync(default);
Assert.Null(result);
}
}
Loading
Loading