Skip to content

Commit 9ef845e

Browse files
committed
Introduce a LINQ filter preprocessor in MEVD
1 parent 1027cc6 commit 9ef845e

File tree

14 files changed

+204
-285
lines changed

14 files changed

+204
-285
lines changed

dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs

+13-51
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
using System.Diagnostics.CodeAnalysis;
88
using System.Linq;
99
using System.Linq.Expressions;
10-
using System.Reflection;
11-
using System.Runtime.CompilerServices;
1210
using System.Text;
1311
using Microsoft.Extensions.VectorData.ConnectorSupport;
12+
using Microsoft.Extensions.VectorData.ConnectorSupport.Filter;
1413

1514
namespace Microsoft.SemanticKernel.Connectors.AzureAISearch;
1615

@@ -32,7 +31,11 @@ internal string Translate(LambdaExpression lambdaExpression, VectorStoreRecordMo
3231
Debug.Assert(lambdaExpression.Parameters.Count == 1);
3332
this._recordParameter = lambdaExpression.Parameters[0];
3433

35-
this.Translate(lambdaExpression.Body);
34+
var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true };
35+
var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body);
36+
37+
this.Translate(preprocessedExpression);
38+
3639
return this._filter.ToString();
3740
}
3841

@@ -139,20 +142,13 @@ private void GenerateLiteral(object? value)
139142

140143
private void TranslateMember(MemberExpression memberExpression)
141144
{
142-
switch (memberExpression)
145+
if (this.TryBindProperty(memberExpression, out var property))
143146
{
144-
case var _ when this.TryBindProperty(memberExpression, out var property):
145-
this._filter.Append(property.StorageName); // TODO: Escape
146-
return;
147-
148-
// Identify captured lambda variables, inline them as constants
149-
case var _ when TryGetCapturedValue(memberExpression, out var capturedValue):
150-
this.GenerateLiteral(capturedValue);
151-
return;
152-
153-
default:
154-
throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported");
147+
this._filter.Append(property.StorageName); // TODO: Escape
148+
return;
155149
}
150+
151+
throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported");
156152
}
157153

158154
private void TranslateMethodCall(MethodCallExpression methodCall)
@@ -207,7 +203,7 @@ private void TranslateContains(Expression source, Expression item)
207203

208204
for (var i = 0; i < newArray.Expressions.Count; i++)
209205
{
210-
if (!TryGetConstant(newArray.Expressions[i], out var elementValue))
206+
if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue })
211207
{
212208
throw new NotSupportedException("Invalid element in array");
213209
}
@@ -223,9 +219,7 @@ private void TranslateContains(Expression source, Expression item)
223219
ProcessInlineEnumerable(elements, item);
224220
return;
225221

226-
// Contains over captured enumerable (we inline)
227-
case var _ when TryGetConstant(source, out var constantEnumerable)
228-
&& constantEnumerable is IEnumerable enumerable and not string:
222+
case ConstantExpression { Value: IEnumerable enumerable and not string }:
229223
ProcessInlineEnumerable(enumerable, item);
230224
return;
231225

@@ -372,36 +366,4 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Vect
372366

373367
return true;
374368
}
375-
376-
private static bool TryGetCapturedValue(Expression expression, out object? capturedValue)
377-
{
378-
if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo }
379-
&& constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
380-
&& Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
381-
{
382-
capturedValue = fieldInfo.GetValue(constant.Value);
383-
return true;
384-
}
385-
386-
capturedValue = null;
387-
return false;
388-
}
389-
390-
private static bool TryGetConstant(Expression expression, out object? constantValue)
391-
{
392-
switch (expression)
393-
{
394-
case ConstantExpression { Value: var v }:
395-
constantValue = v;
396-
return true;
397-
398-
case var _ when TryGetCapturedValue(expression, out var capturedValue):
399-
constantValue = capturedValue;
400-
return true;
401-
402-
default:
403-
constantValue = null;
404-
return false;
405-
}
406-
}
407369
}

dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs

+11-30
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
using System.Diagnostics.CodeAnalysis;
88
using System.Linq;
99
using System.Linq.Expressions;
10-
using System.Reflection;
11-
using System.Runtime.CompilerServices;
1210
using Microsoft.Extensions.VectorData.ConnectorSupport;
11+
using Microsoft.Extensions.VectorData.ConnectorSupport.Filter;
1312
using MongoDB.Bson;
1413

1514
namespace Microsoft.SemanticKernel.Connectors.MongoDB;
@@ -28,7 +27,10 @@ internal BsonDocument Translate(LambdaExpression lambdaExpression, VectorStoreRe
2827
Debug.Assert(lambdaExpression.Parameters.Count == 1);
2928
this._recordParameter = lambdaExpression.Parameters[0];
3029

31-
return this.Translate(lambdaExpression.Body);
30+
var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = true };
31+
var preprocessedExpression = preprocessor.Visit(lambdaExpression.Body);
32+
33+
return this.Translate(preprocessedExpression);
3234
}
3335

3436
private BsonDocument Translate(Expression? node)
@@ -57,9 +59,10 @@ or ExpressionType.LessThan or ExpressionType.LessThanOrEqual
5759
};
5860

5961
private BsonDocument TranslateEqualityComparison(BinaryExpression binary)
60-
=> (this.TryBindProperty(binary.Left, out var property) && TryGetConstant(binary.Right, out var value))
61-
|| (this.TryBindProperty(binary.Right, out property) && TryGetConstant(binary.Left, out value))
62-
? this.GenerateEqualityComparison(property, value, binary.NodeType)
62+
=> this.TryBindProperty(binary.Left, out var property) && binary.Right is ConstantExpression { Value: var rightConstant }
63+
? this.GenerateEqualityComparison(property, rightConstant, binary.NodeType)
64+
: this.TryBindProperty(binary.Right, out property) && binary.Left is ConstantExpression { Value: var leftConstant }
65+
? this.GenerateEqualityComparison(property, leftConstant, binary.NodeType)
6366
: throw new NotSupportedException("Invalid equality/comparison");
6467

6568
private BsonDocument GenerateEqualityComparison(VectorStoreRecordPropertyModel property, object? value, ExpressionType nodeType)
@@ -184,7 +187,7 @@ private BsonDocument TranslateContains(Expression source, Expression item)
184187

185188
for (var i = 0; i < newArray.Expressions.Count; i++)
186189
{
187-
if (!TryGetConstant(newArray.Expressions[i], out var elementValue))
190+
if (newArray.Expressions[i] is not ConstantExpression { Value: var elementValue })
188191
{
189192
throw new NotSupportedException("Invalid element in array");
190193
}
@@ -195,8 +198,7 @@ private BsonDocument TranslateContains(Expression source, Expression item)
195198
return ProcessInlineEnumerable(elements, item);
196199

197200
// Contains over captured enumerable (we inline)
198-
case var _ when TryGetConstant(source, out var constantEnumerable)
199-
&& constantEnumerable is IEnumerable enumerable and not string:
201+
case ConstantExpression { Value: IEnumerable enumerable and not string }:
200202
return ProcessInlineEnumerable(enumerable, item);
201203

202204
default:
@@ -265,25 +267,4 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Vect
265267

266268
return true;
267269
}
268-
269-
private static bool TryGetConstant(Expression expression, out object? constantValue)
270-
{
271-
switch (expression)
272-
{
273-
case ConstantExpression { Value: var v }:
274-
constantValue = v;
275-
return true;
276-
277-
// This identifies compiler-generated closure types which contain captured variables.
278-
case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo }
279-
when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
280-
&& Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true):
281-
constantValue = fieldInfo.GetValue(constant.Value);
282-
return true;
283-
284-
default:
285-
constantValue = null;
286-
return false;
287-
}
288-
}
289270
}

dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs

+5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using System.Runtime.CompilerServices;
1111
using System.Text;
1212
using Microsoft.Extensions.VectorData.ConnectorSupport;
13+
using Microsoft.Extensions.VectorData.ConnectorSupport.Filter;
1314

1415
namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL;
1516

@@ -30,7 +31,11 @@ internal class AzureCosmosDBNoSqlFilterTranslator
3031
Debug.Assert(lambdaExpression.Parameters.Count == 1);
3132
this._recordParameter = lambdaExpression.Parameters[0];
3233

34+
var preprocessor = new FilterTranslationPreprocessor { InlineCapturedVariables = false };
35+
var preprocessedExpression = preprocessor.Visit(lambdaExpression);
36+
3337
this.Translate(lambdaExpression.Body);
38+
3439
return (this._sql.ToString(), this._parameters);
3540
}
3641

dotnet/src/Connectors/Connectors.Memory.Common/SqlFilterTranslator.cs

+19-36
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
using System.Diagnostics.CodeAnalysis;
77
using System.Linq;
88
using System.Linq.Expressions;
9-
using System.Reflection;
10-
using System.Runtime.CompilerServices;
119
using System.Text;
1210
using Microsoft.Extensions.VectorData.ConnectorSupport;
11+
using Microsoft.Extensions.VectorData.ConnectorSupport.Filter;
1312

1413
namespace Microsoft.SemanticKernel.Connectors;
1514

@@ -43,7 +42,10 @@ internal void Translate(bool appendWhere)
4342
this._sql.Append("WHERE ");
4443
}
4544

46-
this.Translate(this._lambdaExpression.Body, isSearchCondition: true);
45+
var preprocessor = new FilterTranslationPreprocessor { TransformCapturedVariablesToQueryParameterExpressions = true };
46+
var preprocessedExpression = preprocessor.Visit(this._lambdaExpression.Body);
47+
48+
this.Translate(preprocessedExpression, isSearchCondition: true);
4749
}
4850

4951
protected void Translate(Expression? node, bool isSearchCondition = false)
@@ -58,6 +60,10 @@ protected void Translate(Expression? node, bool isSearchCondition = false)
5860
this.TranslateConstant(constant.Value);
5961
return;
6062

63+
case QueryParameterExpression { Name: var name, Value: var value }:
64+
this.TranslateQueryParameter(name, value);
65+
return;
66+
6167
case MemberExpression member:
6268
this.TranslateMember(member, isSearchCondition);
6369
return;
@@ -127,8 +133,7 @@ protected void TranslateBinary(BinaryExpression binary)
127133
this._sql.Append(')');
128134

129135
static bool IsNull(Expression expression)
130-
=> expression is ConstantExpression { Value: null }
131-
|| (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null);
136+
=> expression is ConstantExpression { Value: null } or QueryParameterExpression { Value: null };
132137
}
133138

134139
protected virtual void TranslateConstant(object? value)
@@ -175,25 +180,19 @@ protected virtual void TranslateConstant(object? value)
175180

176181
private void TranslateMember(MemberExpression memberExpression, bool isSearchCondition)
177182
{
178-
switch (memberExpression)
183+
if (this.TryBindProperty(memberExpression, out var property))
179184
{
180-
case var _ when this.TryBindProperty(memberExpression, out var property):
181-
this.GenerateColumn(property.StorageName, isSearchCondition);
182-
return;
183-
184-
case var _ when TryGetCapturedValue(memberExpression, out var name, out var value):
185-
this.TranslateCapturedVariable(name, value);
186-
return;
187-
188-
default:
189-
throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported");
185+
this.GenerateColumn(property.StorageName, isSearchCondition);
186+
return;
190187
}
188+
189+
throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported");
191190
}
192191

193192
protected virtual void GenerateColumn(string column, bool isSearchCondition = false)
194193
=> this._sql.Append('"').Append(column.Replace("\"", "\"\"")).Append('"');
195194

196-
protected abstract void TranslateCapturedVariable(string name, object? capturedValue);
195+
protected abstract void TranslateQueryParameter(string name, object? value);
197196

198197
private void TranslateMethodCall(MethodCallExpression methodCall, bool isSearchCondition = false)
199198
{
@@ -262,8 +261,8 @@ private void TranslateContains(Expression source, Expression item)
262261
return;
263262

264263
// Contains over captured array (r => arrayLocalVariable.Contains(r.String))
265-
case var _ when TryGetCapturedValue(source, out _, out var value):
266-
this.TranslateContainsOverCapturedArray(source, item, value);
264+
case QueryParameterExpression { Value: var value }:
265+
this.TranslateContainsOverParameterizedArray(source, item, value);
267266
return;
268267

269268
default:
@@ -273,7 +272,7 @@ private void TranslateContains(Expression source, Expression item)
273272

274273
protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item);
275274

276-
protected abstract void TranslateContainsOverCapturedArray(Expression source, Expression item, object? value);
275+
protected abstract void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value);
277276

278277
private void TranslateUnary(UnaryExpression unary, bool isSearchCondition)
279278
{
@@ -351,20 +350,4 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Vect
351350

352351
return true;
353352
}
354-
355-
private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value)
356-
{
357-
if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo }
358-
&& constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate)
359-
&& Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
360-
{
361-
name = fieldInfo.Name;
362-
value = fieldInfo.GetValue(constant.Value);
363-
return true;
364-
}
365-
366-
name = null;
367-
value = null;
368-
return false;
369-
}
370353
}

0 commit comments

Comments
 (0)