Skip to content

Commit ae446fb

Browse files
authored
Augment AIJsonUtilities.CreateJsonSchema for more types and annotations (#6540)
* Augment AIJsonUtilities.CreateJsonSchema for more types and annotations * Stop suppressing existing format handling, and allow most annotations on netfx
1 parent 3d2ddda commit ae446fb

File tree

6 files changed

+866
-53
lines changed

6 files changed

+866
-53
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
<ItemGroup Condition="'$(TargetFramework)' == 'net462'">
3838
<Reference Include="System.Net.Http" />
39+
<Reference Include="System.ComponentModel.DataAnnotations" />
3940
</ItemGroup>
4041

4142
</Project>

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.Create.cs

Lines changed: 265 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
using System;
55
using System.ComponentModel;
6+
#if NET || NETFRAMEWORK
7+
using System.ComponentModel.DataAnnotations;
8+
#endif
69
using System.Diagnostics;
710
using System.Diagnostics.CodeAnalysis;
811
using System.Reflection;
@@ -14,11 +17,12 @@
1417
using System.Threading;
1518
using Microsoft.Shared.Diagnostics;
1619

17-
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
1820
#pragma warning disable S107 // Methods should not have too many parameters
21+
#pragma warning disable S109 // Magic numbers should not be used
1922
#pragma warning disable S1075 // URIs should not be hardcoded
23+
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
24+
#pragma warning disable S1199 // Nested block
2025
#pragma warning disable SA1118 // Parameter should not span multiple lines
21-
#pragma warning disable S109 // Magic numbers should not be used
2226

2327
namespace Microsoft.Extensions.AI;
2428

@@ -38,14 +42,25 @@ public static partial class AIJsonUtilities
3842
private const string AdditionalPropertiesPropertyName = "additionalProperties";
3943
private const string DefaultPropertyName = "default";
4044
private const string RefPropertyName = "$ref";
45+
#if NET || NETFRAMEWORK
46+
private const string FormatPropertyName = "format";
47+
private const string MinLengthStringPropertyName = "minLength";
48+
private const string MaxLengthStringPropertyName = "maxLength";
49+
private const string MinLengthCollectionPropertyName = "minItems";
50+
private const string MaxLengthCollectionPropertyName = "maxItems";
51+
private const string MinRangePropertyName = "minimum";
52+
private const string MaxRangePropertyName = "maximum";
53+
#endif
54+
#if NET
55+
private const string ContentEncodingPropertyName = "contentEncoding";
56+
private const string ContentMediaTypePropertyName = "contentMediaType";
57+
private const string MinExclusiveRangePropertyName = "exclusiveMinimum";
58+
private const string MaxExclusiveRangePropertyName = "exclusiveMaximum";
59+
#endif
4160

4261
/// <summary>The uri used when populating the $schema keyword in created schemas.</summary>
4362
private const string SchemaKeywordUri = "https://json-schema.org/draft/2020-12/schema";
4463

45-
// List of keywords used by JsonSchemaExporter but explicitly disallowed by some AI vendors.
46-
// cf. https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
47-
private static readonly string[] _schemaKeywordsDisallowedByAIVendors = ["minLength", "maxLength", "pattern", "format"];
48-
4964
/// <summary>
5065
/// Determines a JSON schema for the provided method.
5166
/// </summary>
@@ -280,12 +295,6 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
280295
objSchema.InsertAtStart(TypePropertyName, new JsonArray { (JsonNode)"string", (JsonNode)"null" });
281296
}
282297

283-
// Filter potentially disallowed keywords.
284-
foreach (string keyword in _schemaKeywordsDisallowedByAIVendors)
285-
{
286-
_ = objSchema.Remove(keyword);
287-
}
288-
289298
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
290299
// schemas with "type": [...], and only understand "type" being a single value.
291300
// In certain configurations STJ represents .NET numeric types as ["string", "number"], which will then lead to an error.
@@ -318,6 +327,8 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext schemaExporterContext, Js
318327
ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri);
319328
}
320329

330+
ApplyDataAnnotations(parameterName, ref schema, ctx);
331+
321332
// Finally, apply any user-defined transformations if specified.
322333
if (inferenceOptions.TransformSchemaNode is { } transformer)
323334
{
@@ -345,6 +356,248 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
345356
return obj;
346357
}
347358
}
359+
360+
void ApplyDataAnnotations(string? parameterName, ref JsonNode schema, AIJsonSchemaCreateContext ctx)
361+
{
362+
if (ctx.GetCustomAttribute<DisplayNameAttribute>() is { } displayNameAttribute)
363+
{
364+
ConvertSchemaToObject(ref schema)[TitlePropertyName] ??= displayNameAttribute.DisplayName;
365+
}
366+
367+
#if NET || NETFRAMEWORK
368+
if (ctx.GetCustomAttribute<EmailAddressAttribute>() is { } emailAttribute)
369+
{
370+
ConvertSchemaToObject(ref schema)[FormatPropertyName] ??= "email";
371+
}
372+
373+
if (ctx.GetCustomAttribute<UrlAttribute>() is { } urlAttribute)
374+
{
375+
ConvertSchemaToObject(ref schema)[FormatPropertyName] ??= "uri";
376+
}
377+
378+
if (ctx.GetCustomAttribute<RegularExpressionAttribute>() is { } regexAttribute)
379+
{
380+
ConvertSchemaToObject(ref schema)[PatternPropertyName] ??= regexAttribute.Pattern;
381+
}
382+
383+
if (ctx.GetCustomAttribute<StringLengthAttribute>() is { } stringLengthAttribute)
384+
{
385+
JsonObject obj = ConvertSchemaToObject(ref schema);
386+
387+
if (stringLengthAttribute.MinimumLength > 0)
388+
{
389+
obj[MinLengthStringPropertyName] ??= stringLengthAttribute.MinimumLength;
390+
}
391+
392+
obj[MaxLengthStringPropertyName] ??= stringLengthAttribute.MaximumLength;
393+
}
394+
395+
if (ctx.GetCustomAttribute<MinLengthAttribute>() is { } minLengthAttribute)
396+
{
397+
JsonObject obj = ConvertSchemaToObject(ref schema);
398+
if (obj[TypePropertyName] is JsonNode typeNode && typeNode.GetValueKind() is JsonValueKind.String && typeNode.GetValue<string>() is "string")
399+
{
400+
obj[MinLengthStringPropertyName] ??= minLengthAttribute.Length;
401+
}
402+
else
403+
{
404+
obj[MinLengthCollectionPropertyName] ??= minLengthAttribute.Length;
405+
}
406+
}
407+
408+
if (ctx.GetCustomAttribute<MaxLengthAttribute>() is { } maxLengthAttribute)
409+
{
410+
JsonObject obj = ConvertSchemaToObject(ref schema);
411+
if (obj[TypePropertyName] is JsonNode typeNode && typeNode.GetValueKind() is JsonValueKind.String && typeNode.GetValue<string>() is "string")
412+
{
413+
obj[MaxLengthStringPropertyName] ??= maxLengthAttribute.Length;
414+
}
415+
else
416+
{
417+
obj[MaxLengthCollectionPropertyName] ??= maxLengthAttribute.Length;
418+
}
419+
}
420+
421+
if (ctx.GetCustomAttribute<RangeAttribute>() is { } rangeAttribute)
422+
{
423+
JsonObject obj = ConvertSchemaToObject(ref schema);
424+
425+
JsonNode? minNode = null;
426+
JsonNode? maxNode = null;
427+
switch (rangeAttribute.Minimum)
428+
{
429+
case int minInt32 when rangeAttribute.Maximum is int maxInt32:
430+
maxNode = maxInt32;
431+
if (
432+
#if NET
433+
!rangeAttribute.MinimumIsExclusive ||
434+
#endif
435+
minInt32 > 0)
436+
{
437+
minNode = minInt32;
438+
}
439+
440+
break;
441+
442+
case double minDouble when rangeAttribute.Maximum is double maxDouble:
443+
maxNode = maxDouble;
444+
if (
445+
#if NET
446+
!rangeAttribute.MinimumIsExclusive ||
447+
#endif
448+
minDouble > 0)
449+
{
450+
minNode = minDouble;
451+
}
452+
453+
break;
454+
455+
case string minString when rangeAttribute.Maximum is string maxString:
456+
maxNode = maxString;
457+
minNode = minString;
458+
break;
459+
}
460+
461+
if (minNode is not null)
462+
{
463+
#if NET
464+
if (rangeAttribute.MinimumIsExclusive)
465+
{
466+
obj[MinExclusiveRangePropertyName] ??= minNode;
467+
}
468+
else
469+
#endif
470+
{
471+
obj[MinRangePropertyName] ??= minNode;
472+
}
473+
}
474+
475+
if (maxNode is not null)
476+
{
477+
#if NET
478+
if (rangeAttribute.MaximumIsExclusive)
479+
{
480+
obj[MaxExclusiveRangePropertyName] ??= maxNode;
481+
}
482+
else
483+
#endif
484+
{
485+
obj[MaxRangePropertyName] ??= maxNode;
486+
}
487+
}
488+
}
489+
#endif
490+
491+
#if NET
492+
if (ctx.GetCustomAttribute<Base64StringAttribute>() is { } base64Attribute)
493+
{
494+
ConvertSchemaToObject(ref schema)[ContentEncodingPropertyName] ??= "base64";
495+
}
496+
497+
if (ctx.GetCustomAttribute<LengthAttribute>() is { } lengthAttribute)
498+
{
499+
JsonObject obj = ConvertSchemaToObject(ref schema);
500+
501+
if (obj[TypePropertyName] is JsonNode typeNode && typeNode.GetValueKind() is JsonValueKind.String && typeNode.GetValue<string>() is "string")
502+
{
503+
if (lengthAttribute.MinimumLength > 0)
504+
{
505+
obj[MinLengthStringPropertyName] ??= lengthAttribute.MinimumLength;
506+
}
507+
508+
obj[MaxLengthStringPropertyName] ??= lengthAttribute.MaximumLength;
509+
}
510+
else
511+
{
512+
if (lengthAttribute.MinimumLength > 0)
513+
{
514+
obj[MinLengthCollectionPropertyName] ??= lengthAttribute.MinimumLength;
515+
}
516+
517+
obj[MaxLengthCollectionPropertyName] ??= lengthAttribute.MaximumLength;
518+
}
519+
}
520+
521+
if (ctx.GetCustomAttribute<AllowedValuesAttribute>() is { } allowedValuesAttribute)
522+
{
523+
JsonObject obj = ConvertSchemaToObject(ref schema);
524+
if (!obj.ContainsKey(EnumPropertyName))
525+
{
526+
if (CreateJsonArray(allowedValuesAttribute.Values, serializerOptions) is { Count: > 0 } enumArray)
527+
{
528+
obj[EnumPropertyName] = enumArray;
529+
}
530+
}
531+
}
532+
533+
if (ctx.GetCustomAttribute<DeniedValuesAttribute>() is { } deniedValuesAttribute)
534+
{
535+
JsonObject obj = ConvertSchemaToObject(ref schema);
536+
537+
JsonNode? notNode = obj[NotPropertyName];
538+
if (notNode is null or JsonObject)
539+
{
540+
JsonObject notObj =
541+
notNode as JsonObject ??
542+
(JsonObject)(obj[NotPropertyName] = new JsonObject());
543+
544+
if (notObj[EnumPropertyName] is null)
545+
{
546+
if (CreateJsonArray(deniedValuesAttribute.Values, serializerOptions) is { Count: > 0 } enumArray)
547+
{
548+
notObj[EnumPropertyName] = enumArray;
549+
}
550+
}
551+
}
552+
}
553+
554+
static JsonArray CreateJsonArray(object?[] values, JsonSerializerOptions serializerOptions)
555+
{
556+
JsonArray enumArray = new();
557+
foreach (object? allowedValue in values)
558+
{
559+
if (allowedValue is not null && JsonSerializer.SerializeToNode(allowedValue, serializerOptions.GetTypeInfo(allowedValue.GetType())) is { } valueNode)
560+
{
561+
enumArray.Add(valueNode);
562+
}
563+
}
564+
565+
return enumArray;
566+
}
567+
568+
if (ctx.GetCustomAttribute<DataTypeAttribute>() is { } dataTypeAttribute)
569+
{
570+
JsonObject obj = ConvertSchemaToObject(ref schema);
571+
switch (dataTypeAttribute.DataType)
572+
{
573+
case DataType.DateTime:
574+
obj[FormatPropertyName] ??= "date-time";
575+
break;
576+
577+
case DataType.Date:
578+
obj[FormatPropertyName] ??= "date";
579+
break;
580+
581+
case DataType.Time:
582+
obj[FormatPropertyName] ??= "time";
583+
break;
584+
585+
case DataType.EmailAddress:
586+
obj[FormatPropertyName] ??= "email";
587+
break;
588+
589+
case DataType.Url:
590+
obj[FormatPropertyName] ??= "uri";
591+
break;
592+
593+
case DataType.ImageUrl:
594+
obj[FormatPropertyName] ??= "uri";
595+
obj[ContentMediaTypePropertyName] ??= "image/*";
596+
break;
597+
}
598+
}
599+
#endif
600+
}
348601
}
349602
}
350603

0 commit comments

Comments
 (0)