Skip to content

added SchemaConverter interface for controlling Schema generation #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.sashirestela.openai.common.function;

import io.github.sashirestela.openai.support.JsonSchemaUtil;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
Expand All @@ -16,4 +17,7 @@ public class FunctionDef {
@NonNull
private Class<? extends Functional> functionalClass;

@Builder.Default
private SchemaConverter schemaConverter = JsonSchemaUtil.defaultConverter;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.github.sashirestela.openai.common.function;

import com.fasterxml.jackson.databind.JsonNode;

public interface SchemaConverter {

JsonNode convert(Class<?> c);

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.JsonNode;
import io.github.sashirestela.openai.common.function.FunctionDef;
import io.github.sashirestela.openai.support.JsonSchemaUtil;
import io.github.sashirestela.slimvalidator.constraints.Required;
import io.github.sashirestela.slimvalidator.constraints.Size;
import lombok.AllArgsConstructor;
Expand All @@ -26,7 +25,7 @@ public static Tool function(FunctionDef function) {
new ToolFunctionDef(
function.getName(),
function.getDescription(),
JsonSchemaUtil.classToJsonSchema(function.getFunctionalClass())));
function.getSchemaConverter().convert(function.getFunctionalClass())));
}

@AllArgsConstructor
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.github.sashirestela.openai.support;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.victools.jsonschema.generator.*;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import io.github.sashirestela.openai.SimpleUncheckedException;
import io.github.sashirestela.openai.common.function.SchemaConverter;

import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS;

public class DefaultSchemaConverter implements SchemaConverter {

private final SchemaGenerator schemaGenerator;
private final ObjectMapper objectMapper;

public DefaultSchemaConverter() {
objectMapper = new ObjectMapper();
var jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
JacksonOption.RESPECT_JSONPROPERTY_ORDER);
var configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
OptionPreset.PLAIN_JSON)
.with(jacksonModule)
.without(Option.SCHEMA_VERSION_INDICATOR);
var config = configBuilder.build();
schemaGenerator = new SchemaGenerator(config);
}

@Override
public JsonNode convert(Class<?> clazz) {
JsonNode jsonSchema;
try {
jsonSchema = schemaGenerator.generateSchema(clazz);
if (jsonSchema.get("properties") == null) {
jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
}

} catch (Exception e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.", clazz.getName(),
e);
}
return jsonSchema;
}

}
Original file line number Diff line number Diff line change
@@ -1,45 +1,19 @@
package io.github.sashirestela.openai.support;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.victools.jsonschema.generator.Option;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import io.github.sashirestela.openai.SimpleUncheckedException;
import io.github.sashirestela.openai.common.function.SchemaConverter;

public class JsonSchemaUtil {

public static final SchemaConverter defaultConverter = new DefaultSchemaConverter();

public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}";
private static ObjectMapper objectMapper = new ObjectMapper();

private JsonSchemaUtil() {
}

public static JsonNode classToJsonSchema(Class<?> clazz) {
JsonNode jsonSchema = null;
try {
var jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
JacksonOption.RESPECT_JSONPROPERTY_ORDER);
var configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
OptionPreset.PLAIN_JSON)
.with(jacksonModule)
.without(Option.SCHEMA_VERSION_INDICATOR);
var config = configBuilder.build();
var generator = new SchemaGenerator(config);
jsonSchema = generator.generateSchema(clazz);
if (jsonSchema.get("properties") == null) {
jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
}

} catch (Exception e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.",
clazz.getName(), e);
}
return jsonSchema;
return defaultConverter.convert(clazz);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.github.sashirestela.openai.support;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.victools.jsonschema.generator.*;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import io.github.sashirestela.openai.SimpleUncheckedException;
import io.github.sashirestela.openai.common.function.SchemaConverter;

public class CustomSchemaConverter implements SchemaConverter {

private final SchemaGenerator schemaGenerator;
private final ObjectMapper objectMapper;
public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}";

public CustomSchemaConverter() {
objectMapper = new ObjectMapper();
var jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
JacksonOption.RESPECT_JSONPROPERTY_ORDER);
var configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12,
OptionPreset.PLAIN_JSON)
.with(jacksonModule)
.with(builder -> builder.forTypesInGeneral()
.withTypeAttributeOverride(
(collectedTypeAttributes, scope, context) -> collectedTypeAttributes
.put("myCustomProperty", true)))
.without(Option.SCHEMA_VERSION_INDICATOR);
var config = configBuilder.build();
schemaGenerator = new SchemaGenerator(config);
}

@Override
public JsonNode convert(Class<?> clazz) {
JsonNode jsonSchema;
try {
jsonSchema = schemaGenerator.generateSchema(clazz);
if (jsonSchema.get("properties") == null) {
jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
}

} catch (Exception e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.", clazz.getName(),
e);
}
return jsonSchema;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package io.github.sashirestela.openai.support;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import io.github.sashirestela.openai.common.function.SchemaConverter;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.Test;

import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS;
import static org.junit.jupiter.api.Assertions.assertEquals;

class CustomSchemaConverterTest {

private static SchemaConverter schemaConverter = new CustomSchemaConverter();

@Test
void shouldGenerateFullJsonSchemaWhenClassHasSomeFields() {
var actualJsonSchema = schemaConverter.convert(TestClass.class).toString();
var expectedJsonSchema = "{\"type\":\"object\",\"properties\":{\"first\":{\"type\":\"string\",\"myCustomProperty\":true},\"second\":{\"type\":\"integer\",\"myCustomProperty\":true}},\"required\":[\"first\"],\"myCustomProperty\":true}";
assertEquals(expectedJsonSchema, actualJsonSchema);
}

@Test
void shouldGenerateEmptyJsonSchemaWhenClassHasNoFields() {
var actualJsonSchema = schemaConverter.convert(EmptyClass.class).toString();
var expectedJsonSchema = JSON_EMPTY_CLASS;
assertEquals(expectedJsonSchema, actualJsonSchema);
}

@Test
void shouldGenerateOrderedJsonSchemaWhenClassHasJsonPropertyOrderAnnotation() {
var actualJsonSchema = schemaConverter.convert(OrderedTestClass.class).toString();
var expectedJsonSchema = "{\"type\":\"object\",\"properties\":{\"first\":{\"type\":\"string\",\"myCustomProperty\":true},\"second\":{\"type\":\"integer\",\"myCustomProperty\":true},\"third\":{\"type\":\"string\",\"myCustomProperty\":true}},\"required\":[\"first\"],\"myCustomProperty\":true}";
assertEquals(expectedJsonSchema, actualJsonSchema);
}

@NoArgsConstructor
@AllArgsConstructor
@Getter
static class TestClass {

@JsonProperty(required = true)
public String first;

public Integer second;

}

static class EmptyClass {
}

@NoArgsConstructor
@AllArgsConstructor
@Getter
@JsonPropertyOrder({ "first", "second", "third" })
static class OrderedTestClass {

@JsonProperty(required = true)
public String first;

public Integer second;

public String third;

}

}