From 7513ab244851eadeaee5374941dec065114cbb26 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 28 Mar 2025 12:51:45 +0100 Subject: [PATCH 1/5] Sqlite: escape table and column names by using square brackets --- .../Conditions/SqliteWhereCondition.cs | 4 +-- ...liteVectorStoreCollectionCommandBuilder.cs | 29 ++++++++++++------- .../SqliteConditionsTests.cs | 18 ++++++------ ...ectorStoreCollectionCommandBuilderTests.cs | 24 +++++++-------- .../CRUD/SqliteRecordConformanceTests.cs | 12 ++++++++ .../Support/SqliteSimpleModelFixture.cs | 12 ++++++++ .../Support/SqliteTestStore.cs | 2 ++ 7 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs index ea3f702a42b8..71b3168efe72 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/Conditions/SqliteWhereCondition.cs @@ -15,6 +15,6 @@ internal abstract class SqliteWhereCondition(string operand, List values public abstract string BuildQuery(List parameterNames); protected string GetOperand() => !string.IsNullOrWhiteSpace(this.TableName) ? - $"{this.TableName}.{this.Operand}" : - this.Operand; + $"[{this.TableName}].[{this.Operand}]" : + $"[{this.Operand}]"; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 6707bf482fed..7a6d9a26b5f7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -39,7 +39,7 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str { var builder = new StringBuilder(); - builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} ("); + builder.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] ("); builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition))); builder.Append(");"); @@ -48,7 +48,8 @@ public static DbCommand BuildCreateTableCommand(SqliteConnection connection, str { if (column.HasIndex) { - builder.AppendLine($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName}_{column.Name}_index ON {tableName}({column.Name});"); + builder.AppendLine(); + builder.Append($"CREATE INDEX {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}_{column.Name}_index] ON [{tableName}]([{column.Name}]);"); } } @@ -68,7 +69,7 @@ public static DbCommand BuildCreateVirtualTableCommand( { var builder = new StringBuilder(); - builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}{tableName} USING {extensionName}("); + builder.AppendLine($"CREATE VIRTUAL TABLE {(ifNotExists ? "IF NOT EXISTS " : string.Empty)}[{tableName}] USING {extensionName}("); builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition))); builder.Append(");"); @@ -113,7 +114,7 @@ public static DbCommand BuildInsertCommand( records[recordIndex], recordIndex); - builder.AppendLine($"INSERT{replacePlaceholder} INTO {tableName} ({string.Join(", ", columns)})"); + builder.AppendLine($"INSERT{replacePlaceholder} INTO [{tableName}] ({string.Join(", ", columns)})"); builder.AppendLine($"VALUES ({string.Join(", ", parameters)})"); builder.AppendLine($"RETURNING {rowIdentifier};"); @@ -139,8 +140,14 @@ public static DbCommand BuildSelectCommand( var (command, whereClause) = GetCommandWithWhereClause(connection, conditions); - builder.AppendLine($"SELECT {string.Join(", ", columnNames)}"); - builder.AppendLine($"FROM {tableName}"); + builder.Append("SELECT "); + foreach (var columnName in columnNames) + { + builder.AppendFormat("[{0}],", columnName); + } + builder.Length--; // Remove the last comma + builder.AppendLine(); + builder.AppendLine($"FROM [{tableName}]"); AppendWhereClauseIfExists(builder, whereClause); AppendOrderByIfExists(builder, orderByPropertyName); @@ -166,15 +173,15 @@ public static DbCommand BuildSelectLeftJoinCommand( List propertyNames = [ - .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"), - .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), + .. leftTablePropertyNames.Select(property => $"[{leftTable}].[{property}]"), + .. rightTablePropertyNames.Select(property => $"[{rightTable}].[{property}]"), ]; var (command, whereClause) = GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters); builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}"); - builder.AppendLine($"FROM {leftTable} "); - builder.AppendLine($"LEFT JOIN {rightTable} ON {leftTable}.{joinColumnName} = {rightTable}.{joinColumnName}"); + builder.AppendLine($"FROM [{leftTable}] "); + builder.AppendLine($"LEFT JOIN [{rightTable}] ON [{leftTable}].[{joinColumnName}] = [{rightTable}].[{joinColumnName}]"); AppendWhereClauseIfExists(builder, whereClause); AppendOrderByIfExists(builder, orderByPropertyName); @@ -301,7 +308,7 @@ private static (List Columns, List ParameterNames, List { if (record.TryGetValue(propertyName, out var value)) { - columns.Add(propertyName); + columns.Add($"[{propertyName}]"); parameterNames.Add(GetParameterName(propertyName, index)); parameterValues.Add(value ?? DBNull.Value); } diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs index 7f02575e9b88..4e73d5ce0fe0 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteConditionsTests.cs @@ -22,9 +22,9 @@ public void SqliteWhereEqualsConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name = @Name0")] - [InlineData("", "Name = @Name0")] - [InlineData("TableName", "TableName.Name = @Name0")] + [InlineData(null, "[Name] = @Name0")] + [InlineData("", "[Name] = @Name0")] + [InlineData("TableName", "[TableName].[Name] = @Name0")] public void SqliteWhereEqualsConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange @@ -48,9 +48,9 @@ public void SqliteWhereInConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name IN (@Name0, @Name1)")] - [InlineData("", "Name IN (@Name0, @Name1)")] - [InlineData("TableName", "TableName.Name IN (@Name0, @Name1)")] + [InlineData(null, "[Name] IN (@Name0, @Name1)")] + [InlineData("", "[Name] IN (@Name0, @Name1)")] + [InlineData("TableName", "[TableName].[Name] IN (@Name0, @Name1)")] public void SqliteWhereInConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange @@ -74,9 +74,9 @@ public void SqliteWhereMatchConditionWithoutParameterNamesThrowsException() } [Theory] - [InlineData(null, "Name MATCH @Name0")] - [InlineData("", "Name MATCH @Name0")] - [InlineData("TableName", "TableName.Name MATCH @Name0")] + [InlineData(null, "[Name] MATCH @Name0")] + [InlineData("", "[Name] MATCH @Name0")] + [InlineData("TableName", "[TableName].[Name] MATCH @Name0")] public void SqliteWhereMatchConditionBuildsValidQuery(string? tableName, string expectedQuery) { // Arrange diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs index 7ad790f91089..42f14d5b4ca4 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs @@ -134,7 +134,7 @@ public void ItBuildsInsertCommand(bool replaceIfExists) // Assert Assert.Equal(replaceIfExists, command.CommandText.Contains("OR REPLACE")); - Assert.Contains($"INTO {TableName} (Id, Name, Age, Address)", command.CommandText); + Assert.Contains($"INTO [{TableName}] ([Id], [Name], [Age], [Address])", command.CommandText); Assert.Contains("VALUES (@Id0, @Name0, @Age0, @Address0)", command.CommandText); Assert.Contains("VALUES (@Id1, @Name1, @Age1, @Address1)", command.CommandText); Assert.Contains("RETURNING Id", command.CommandText); @@ -184,11 +184,11 @@ public void ItBuildsSelectCommand(string? orderByPropertyName) var command = SqliteVectorStoreCollectionCommandBuilder.BuildSelectCommand(this._connection, TableName, columnNames, conditions, orderByPropertyName); // Assert - Assert.Contains("SELECT Id, Name, Age, Address", command.CommandText); - Assert.Contains($"FROM {TableName}", command.CommandText); + Assert.Contains("SELECT [Id],[Name],[Age],[Address]", command.CommandText); + Assert.Contains($"FROM [{TableName}]", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("[Name] = @Name0", command.CommandText); + Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText); Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}")); @@ -239,13 +239,13 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName) orderByPropertyName); // Assert - Assert.Contains("SELECT LeftTable.Id, LeftTable.Name, RightTable.Age, RightTable.Address", command.CommandText); - Assert.Contains("FROM LeftTable", command.CommandText); + Assert.Contains("SELECT [LeftTable].[Id], [LeftTable].[Name], [RightTable].[Age], [RightTable].[Address]", command.CommandText); + Assert.Contains("FROM [LeftTable]", command.CommandText); - Assert.Contains("LEFT JOIN RightTable ON LeftTable.Id = RightTable.Id", command.CommandText); + Assert.Contains("LEFT JOIN [RightTable] ON [LeftTable].[Id] = [RightTable].[Id]", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("[Name] = @Name0", command.CommandText); + Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText); Assert.Equal(!string.IsNullOrWhiteSpace(orderByPropertyName), command.CommandText.Contains($"ORDER BY {orderByPropertyName}")); @@ -280,8 +280,8 @@ public void ItBuildsDeleteCommand() // Assert Assert.Contains("DELETE FROM [TestTable]", command.CommandText); - Assert.Contains("Name = @Name0", command.CommandText); - Assert.Contains("Age IN (@Age0, @Age1, @Age2)", command.CommandText); + Assert.Contains("[Name] = @Name0", command.CommandText); + Assert.Contains("[Age] IN (@Age0, @Age1, @Age2)", command.CommandText); Assert.Equal("@Name0", command.Parameters[0].ParameterName); Assert.Equal("NameValue", command.Parameters[0].Value); diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs new file mode 100644 index 000000000000..402c16b02621 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteRecordConformanceTests(SqliteSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs new file mode 100644 index 000000000000..345f4005cbe6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Support; + +public class SqliteSimpleModelFixture : SimpleModelFixture +{ + public override TestStore TestStore => SqliteTestStore.Instance; + + public override string DefaultDistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs index 9b025c66610f..3ea3b05d69d7 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -16,6 +16,8 @@ internal sealed class SqliteTestStore : TestStore public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + public override string DefaultDistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + private SqliteTestStore() { } From 5730353aae46841914f575515d4155779579a22f Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 28 Mar 2025 13:45:19 +0100 Subject: [PATCH 2/5] Sqlite: treat empty batches as nop as others connectors do --- .../SqliteVectorStoreRecordCollection.cs | 72 +++++++++++++------ .../CRUD/SqliteBatchConformanceTests.cs | 17 +++++ .../CRUD/SqliteRecordConformanceTests.cs | 15 +++- .../Support/SqliteSimpleModelFixture.cs | 3 +- .../CRUD/RecordConformanceTests.cs | 2 +- 5 files changed, 83 insertions(+), 26 deletions(-) create mode 100644 dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 16dbd7238aca..79c0226d881c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -240,8 +240,14 @@ public virtual Task> VectorizedSearchAsync /// public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + var keysList = GetKeysAsListOfObjects(keys); + if (keysList.Count == 0) + { + yield break; + } + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); - await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false)) + await foreach (var record in this.InternalGetBatchAsync(connection, keysList, options, cancellationToken).ConfigureAwait(false)) { yield return record; } @@ -257,6 +263,8 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancellat /// public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + Verify.NotNull(records); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken) .ConfigureAwait(false)) @@ -275,8 +283,14 @@ public async Task DeleteAsync(ulong key, CancellationToken cancellationToken = d /// public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { + var keysList = GetKeysAsListOfObjects(keys); + if (keysList.Count == 0) + { + return; + } + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); - await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false); + await this.InternalDeleteBatchAsync(connection, keysList, cancellationToken).ConfigureAwait(false); } #endregion @@ -293,8 +307,14 @@ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken ca /// public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + var keysList = GetKeysAsListOfObjects(keys); + if (keysList.Count == 0) + { + yield break; + } + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); - await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false)) + await foreach (var record in this.InternalGetBatchAsync(connection, keysList, options, cancellationToken).ConfigureAwait(false)) { yield return record; } @@ -313,6 +333,8 @@ async IAsyncEnumerable IVectorStoreRecordCollection.Ups IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken) { + Verify.NotNull(records); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken) .ConfigureAwait(false)) @@ -324,6 +346,8 @@ async IAsyncEnumerable IVectorStoreRecordCollection.Ups /// public async Task DeleteAsync(string key, CancellationToken cancellationToken = default) { + Verify.NotNull(key); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); await this.InternalDeleteAsync(connection, key, cancellationToken) .ConfigureAwait(false); @@ -332,8 +356,14 @@ await this.InternalDeleteAsync(connection, key, cancellationToken) /// public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { + var keysList = GetKeysAsListOfObjects(keys); + if (keysList.Count == 0) + { + return; + } + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); - await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false); + await this.InternalDeleteBatchAsync(connection, keysList, cancellationToken).ConfigureAwait(false); } #endregion @@ -480,18 +510,12 @@ private Task DropTableAsync(SqliteConnection connection, string tableName, .ConfigureAwait(false); } - private IAsyncEnumerable InternalGetBatchAsync( + private IAsyncEnumerable InternalGetBatchAsync( SqliteConnection connection, - IEnumerable keys, + List keysList, GetRecordOptions? options, CancellationToken cancellationToken) { - Verify.NotNull(keys); - - var keysList = keys.Cast().ToList(); - - Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero."); - var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keysList) { TableName = this._dataTableName @@ -583,6 +607,11 @@ private IAsyncEnumerable InternalUpsertBatchAsync(SqliteConnection c OperationName, () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + if (storageModels.Count == 0) + { + return AsyncEnumerable.Empty(); + } + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); var condition = new SqliteWhereInCondition(this._propertyReader.KeyPropertyStoragePropertyName, keys); @@ -645,24 +674,16 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( private Task InternalDeleteAsync(SqliteConnection connection, TKey key, CancellationToken cancellationToken) { - Verify.NotNull(key); - var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); return this.InternalDeleteBatchAsync(connection, condition, cancellationToken); } - private Task InternalDeleteBatchAsync(SqliteConnection connection, IEnumerable keys, CancellationToken cancellationToken) + private Task InternalDeleteBatchAsync(SqliteConnection connection, List keys, CancellationToken cancellationToken) { - Verify.NotNull(keys); - - var keysList = keys.Cast().ToList(); - - Verify.True(keysList.Count > 0, "Number of provided keys should be greater than zero."); - var condition = new SqliteWhereInCondition( this._propertyReader.KeyPropertyStoragePropertyName, - keysList); + keys); return this.InternalDeleteBatchAsync(connection, condition, cancellationToken); } @@ -811,4 +832,11 @@ private static string GetVectorTableName( } #endregion + + private static List GetKeysAsListOfObjects(IEnumerable keys) + { + Verify.NotNull(keys); + + return keys.Cast().ToList(); + } } diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs new file mode 100644 index 000000000000..21893736060e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteBatchConformanceTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.CRUD; +using Xunit; + +namespace SqliteIntegrationTests.CRUD; + +public class SqliteBatchConformanceTests_string(SqliteSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} + +public class SqliteBatchConformanceTests_ulong(SqliteSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs index 402c16b02621..b6fd9beaa96d 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/CRUD/SqliteRecordConformanceTests.cs @@ -6,7 +6,18 @@ namespace SqliteIntegrationTests.CRUD; -public class SqliteRecordConformanceTests(SqliteSimpleModelFixture fixture) - : RecordConformanceTests(fixture), IClassFixture +public class SqliteRecordConformanceTests_string(SqliteSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> { } + +public class SqliteRecordConformanceTests_ulong(SqliteSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> +{ + public override async Task GetAsyncThrowsArgumentNullExceptionForNullKey() + { + // default(ulong) is a valid key + var result = await fixture.Collection.GetAsync(default); + Assert.Null(result); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs index 345f4005cbe6..4941ccd9005e 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteSimpleModelFixture.cs @@ -4,7 +4,8 @@ namespace SqliteIntegrationTests.Support; -public class SqliteSimpleModelFixture : SimpleModelFixture +public class SqliteSimpleModelFixture : SimpleModelFixture + where TKey : notnull { public override TestStore TestStore => SqliteTestStore.Instance; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs index 5a3d0d0081ea..ab8b1bb884bb 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/CRUD/RecordConformanceTests.cs @@ -10,7 +10,7 @@ namespace VectorDataSpecificationTests.CRUD; public class RecordConformanceTests(SimpleModelFixture fixture) where TKey : notnull { [ConditionalFact] - public async Task GetAsyncThrowsArgumentNullExceptionForNullKey() + public virtual async Task GetAsyncThrowsArgumentNullExceptionForNullKey() { ArgumentNullException ex = await Assert.ThrowsAsync(() => fixture.Collection.GetAsync(default!)); Assert.Equal("key", ex.ParamName); From 65855937de12d72623a9b49cedcf444e6ce0a383 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Fri, 28 Mar 2025 14:18:25 +0100 Subject: [PATCH 3/5] fix the build? --- .../SqliteVectorStoreRecordCollection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 79c0226d881c..ecc6a87a516a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -672,7 +672,7 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( } } - private Task InternalDeleteAsync(SqliteConnection connection, TKey key, CancellationToken cancellationToken) + private Task InternalDeleteAsync(SqliteConnection connection, TKey key, CancellationToken cancellationToken) where TKey : notnull { var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); From f2d9961d59ef7e4ab2493132f0773b01e09e9395 Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Wed, 2 Apr 2025 16:33:30 +0200 Subject: [PATCH 4/5] add storage property names that require all kinds of escaping and quoting --- .../Support/SimpleModelFixture.cs | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs index b5c688c01835..f6214a866b5f 100644 --- a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/SimpleModelFixture.cs @@ -45,16 +45,28 @@ protected override VectorStoreRecordDefinition GetRecordDefinition() { Properties = [ - new VectorStoreRecordKeyProperty(nameof(SimpleModel.Id), typeof(TKey)), + new VectorStoreRecordKeyProperty(nameof(SimpleModel.Id), typeof(TKey)) + { + StoragePropertyName = "i d" // intentionally with a space + }, new VectorStoreRecordVectorProperty(nameof(SimpleModel.Floats), typeof(ReadOnlyMemory?)) { Dimensions = SimpleModel.DimensionCount, DistanceFunction = this.DistanceFunction, - IndexKind = this.IndexKind + IndexKind = this.IndexKind, + StoragePropertyName = "embed\"ding" // intentionally with quotes }, + new VectorStoreRecordDataProperty(nameof(SimpleModel.Number), typeof(int)) + { + IsFilterable = true, + StoragePropertyName = "num'ber" // intentionally with a single quote - new VectorStoreRecordDataProperty(nameof(SimpleModel.Number), typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty(nameof(SimpleModel.Text), typeof(string)) { IsFilterable = true }, + }, + new VectorStoreRecordDataProperty(nameof(SimpleModel.Text), typeof(string)) + { + IsFilterable = true, + StoragePropertyName = "te]xt" // intentionally with a character that requires escaping for Sql Server + } ] }; } From c46828891fcd595a90a9c94fef3f2814f778d37d Mon Sep 17 00:00:00 2001 From: Adam Sitnik Date: Thu, 3 Apr 2025 11:50:53 +0200 Subject: [PATCH 5/5] escape all Sql Server queries --- .../SqlServerCommandBuilder.cs | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs index aebcf8fe8787..df0656395dbb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs @@ -34,20 +34,24 @@ internal static SqlCommand CreateTable( sb.AppendTableName(schema, tableName); sb.AppendLine(" ("); string keyColumnName = GetColumnName(keyProperty); - sb.AppendFormat("[{0}] {1} NOT NULL,", keyColumnName, Map(keyProperty)); + sb.AppendEscaped(keyColumnName); + sb.AppendFormat(" {0} NOT NULL,", Map(keyProperty)); sb.AppendLine(); for (int i = 0; i < dataProperties.Count; i++) { - sb.AppendFormat("[{0}] {1},", GetColumnName(dataProperties[i]), Map(dataProperties[i])); + sb.AppendEscaped(GetColumnName(dataProperties[i])); + sb.AppendFormat(" {0},", Map(dataProperties[i])); sb.AppendLine(); } for (int i = 0; i < vectorProperties.Count; i++) { - sb.AppendFormat("[{0}] VECTOR({1}),", GetColumnName(vectorProperties[i]), vectorProperties[i].Dimensions); + sb.AppendEscaped(GetColumnName(vectorProperties[i])); + sb.AppendFormat(" VECTOR({0}),", vectorProperties[i].Dimensions); sb.AppendLine(); } - sb.AppendFormat("PRIMARY KEY ([{0}])", keyColumnName); - sb.AppendLine(); + sb.Append("PRIMARY KEY ("); + sb.AppendEscaped(keyColumnName); + sb.AppendLine(")"); sb.AppendLine(");"); // end the table definition foreach (var dataProperty in dataProperties) @@ -57,8 +61,9 @@ internal static SqlCommand CreateTable( sb.AppendFormat("CREATE INDEX "); sb.AppendIndexName(tableName, GetColumnName(dataProperty)); sb.AppendFormat(" ON ").AppendTableName(schema, tableName); - sb.AppendFormat("([{0}]);", GetColumnName(dataProperty)); - sb.AppendLine(); + sb.Append('('); + sb.AppendEscaped(GetColumnName(dataProperty)); + sb.AppendLine(");"); } } @@ -140,14 +145,18 @@ internal static SqlCommand MergeIntoSingle( sb.Append(") AS s ("); sb.AppendColumnNames(properties); sb.AppendLine(")"); - sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.Append("ON (t.").AppendEscaped(GetColumnName(keyProperty)); + sb.Append(" = s.").AppendEscaped(GetColumnName(keyProperty)); + sb.AppendLine(")"); sb.AppendLine("WHEN MATCHED THEN"); sb.Append("UPDATE SET "); foreach (VectorStoreRecordProperty property in properties) { if (property != keyProperty) // don't update the key { - sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + sb.Append("t.").AppendEscaped(GetColumnName(property)); + sb.Append(" = s.").AppendEscaped(GetColumnName(property)); + sb.Append(','); } } --sb.Length; // remove the last comma @@ -161,7 +170,7 @@ internal static SqlCommand MergeIntoSingle( sb.Append("VALUES ("); sb.AppendColumnNames(properties, prefix: "s."); sb.AppendLine(")"); - sb.AppendFormat("OUTPUT inserted.[{0}];", GetColumnName(keyProperty)); + sb.Append("OUTPUT inserted.").AppendEscaped(GetColumnName(keyProperty)).Append(';'); command.CommandText = sb.ToString(); return command; @@ -208,14 +217,18 @@ internal static bool MergeIntoMany( sb.Append(") AS s ("); // s stands for source sb.AppendColumnNames(properties); sb.AppendLine(")"); - sb.AppendFormat("ON (t.[{0}] = s.[{0}])", GetColumnName(keyProperty)).AppendLine(); + sb.Append("ON (t.").AppendEscaped(GetColumnName(keyProperty)); + sb.Append(" = s.").AppendEscaped(GetColumnName(keyProperty)); + sb.AppendLine(")"); sb.AppendLine("WHEN MATCHED THEN"); sb.Append("UPDATE SET "); foreach (VectorStoreRecordProperty property in properties) { if (property != keyProperty) // don't update the key { - sb.AppendFormat("t.[{0}] = s.[{0}],", GetColumnName(property)); + sb.Append("t.").AppendEscaped(GetColumnName(property)); + sb.Append(" = s.").AppendEscaped(GetColumnName(property)); + sb.Append(','); } } --sb.Length; // remove the last comma @@ -228,8 +241,8 @@ internal static bool MergeIntoMany( sb.Append("VALUES ("); sb.AppendColumnNames(properties, prefix: "s."); sb.AppendLine(")"); - sb.AppendFormat("OUTPUT inserted.[{0}] INTO @InsertedKeys (KeyColumn);", GetColumnName(keyProperty)); - sb.AppendLine(); + sb.Append("OUTPUT inserted.").AppendEscaped(GetColumnName(keyProperty)); + sb.AppendLine(" INTO @InsertedKeys (KeyColumn);"); // The SELECT statement returns the keys of the inserted rows. sb.Append("SELECT KeyColumn FROM @InsertedKeys;"); @@ -248,7 +261,7 @@ internal static SqlCommand DeleteSingle( StringBuilder sb = new(100); sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); - sb.AppendFormat(" WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.Append(" WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" = "); sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); command.AddParameter(keyProperty, keyParamName, key); @@ -263,7 +276,7 @@ internal static bool DeleteMany( StringBuilder sb = new(100); sb.Append("DELETE FROM "); sb.AppendTableName(schema, tableName); - sb.AppendFormat(" WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.Append(" WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" IN ("); sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause @@ -293,7 +306,7 @@ internal static SqlCommand SelectSingle( sb.Append("FROM "); sb.AppendTableName(schema, collectionName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] = ", GetColumnName(keyProperty)); + sb.Append("WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" = "); sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); command.AddParameter(keyProperty, keyParamName, key); @@ -315,7 +328,7 @@ internal static bool SelectMany( sb.Append("FROM "); sb.AppendTableName(schema, tableName); sb.AppendLine(); - sb.AppendFormat("WHERE [{0}] IN (", GetColumnName(keyProperty)); + sb.Append("WHERE ").AppendEscaped(GetColumnName(keyProperty)).Append(" IN ("); sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); sb.Append(')'); // close the IN clause @@ -346,8 +359,9 @@ internal static SqlCommand SelectVector( sb.AppendFormat("SELECT "); sb.AppendColumnNames(properties, includeVectors: options.IncludeVectors); sb.AppendLine(","); - sb.AppendFormat("VECTOR_DISTANCE('{0}', {1}, CAST(@vector AS VECTOR({2}))) AS [score]", - distanceMetric, GetColumnName(vectorProperty), vector.Length); + sb.AppendFormat("VECTOR_DISTANCE('{0}', ", distanceMetric); + sb.AppendEscaped(GetColumnName(vectorProperty)); + sb.AppendFormat(", CAST(@vector AS VECTOR({0}))) AS [score]", vector.Length); sb.AppendLine(); sb.Append("FROM "); sb.AppendTableName(schema, tableName); @@ -411,25 +425,29 @@ internal static StringBuilder AppendParameterName(this StringBuilder sb, VectorS return sb; } - internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) + internal static StringBuilder AppendEscaped(this StringBuilder sb, string value) { - // If the column name contains a ], then escape it by doubling it. + // If the value contains a ], then escape it by doubling it. // "Name with [brackets]" becomes [Name with [brackets]]]. sb.Append('['); int index = sb.Length; // store the index, so we replace ] only for the appended part + sb.Append(value); + sb.Replace("]", "]]", index, value.Length); // replace the ] for value (escape it) + sb.Append(']'); + return sb; + } + + internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) + { if (!string.IsNullOrEmpty(schema)) { - sb.Append(schema); - sb.Replace("]", "]]", index, schema!.Length); // replace the ] for schema - sb.Append("].["); - index = sb.Length; + sb.AppendEscaped(schema!); + sb.Append('.'); } - sb.Append(tableName); - sb.Replace("]", "]]", index, tableName.Length); - sb.Append(']'); + sb.AppendEscaped(tableName); return sb; } @@ -452,7 +470,8 @@ private static StringBuilder AppendColumnNames(this StringBuilder sb, sb.Append(prefix); } // Use square brackets to escape column names. - sb.AppendFormat("[{0}],", GetColumnName(property)); + sb.AppendEscaped(GetColumnName(property)); + sb.Append(','); any = true; }