Skip to content

Commit 8a891b2

Browse files
committed
Updating FineTuning API including DPO
1 parent 5b5ed48 commit 8a891b2

File tree

6 files changed

+124
-4
lines changed

6 files changed

+124
-4
lines changed

src/main/java/io/github/sashirestela/openai/domain/finetuning/FineTuning.java

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class FineTuning {
3131
private List<Integration> integrations;
3232
private Integer seed;
3333
private Integer estimatedFinish;
34+
private MethodFineTunning method;
3435

3536
@NoArgsConstructor
3637
@Getter

src/main/java/io/github/sashirestela/openai/domain/finetuning/FineTuningRequest.java

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ public class FineTuningRequest {
2525

2626
private String validationFile;
2727

28+
/**
29+
* @deprecated OpenAI has deprecated this field in favor of method; hyperparameters should now be passed in under
30+
* the method parameter.
31+
*/
32+
@Deprecated(since = "3.12.0", forRemoval = true)
2833
private HyperParams hyperparameters;
2934

3035
private String suffix;
@@ -34,4 +39,6 @@ public class FineTuningRequest {
3439

3540
private Integer seed;
3641

42+
private MethodFineTunning method;
43+
3744
}

src/main/java/io/github/sashirestela/openai/domain/finetuning/HyperParams.java

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
2222
public class HyperParams {
2323

24+
@ObjectType(baseClass = Integer.class)
25+
@ObjectType(baseClass = String.class)
26+
private Object beta;
27+
2428
@ObjectType(baseClass = Integer.class)
2529
@ObjectType(baseClass = String.class)
2630
private Object batchSize;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package io.github.sashirestela.openai.domain.finetuning;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
5+
import com.fasterxml.jackson.annotation.JsonProperty;
6+
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
7+
import com.fasterxml.jackson.databind.annotation.JsonNaming;
8+
import lombok.Getter;
9+
import lombok.NoArgsConstructor;
10+
import lombok.ToString;
11+
12+
@Getter
13+
@ToString
14+
@NoArgsConstructor
15+
@JsonInclude(Include.NON_EMPTY)
16+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
17+
public class MethodFineTunning {
18+
19+
private MethodType type;
20+
private Supervised supervised;
21+
private Dpo dpo;
22+
23+
private MethodFineTunning(MethodType type, Supervised supervised, Dpo dpo) {
24+
this.type = type;
25+
this.supervised = supervised;
26+
this.dpo = dpo;
27+
}
28+
29+
public static MethodFineTunning supervised(HyperParams hyperParameters) {
30+
return new MethodFineTunning(MethodType.SUPERVISED, new Supervised(hyperParameters), null);
31+
}
32+
33+
public static MethodFineTunning dpo(HyperParams hyperParameters) {
34+
return new MethodFineTunning(MethodType.DPO, null, new Dpo(hyperParameters));
35+
}
36+
37+
public enum MethodType {
38+
39+
@JsonProperty("supervised")
40+
SUPERVISED,
41+
42+
@JsonProperty("dpo")
43+
DPO;
44+
45+
}
46+
47+
@Getter
48+
@ToString
49+
@NoArgsConstructor
50+
@JsonInclude(Include.NON_EMPTY)
51+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
52+
public static class Supervised {
53+
54+
HyperParams hyperParameters;
55+
56+
public Supervised(HyperParams hyperParameters) {
57+
this.hyperParameters = hyperParameters;
58+
}
59+
60+
}
61+
62+
@Getter
63+
@ToString
64+
@NoArgsConstructor
65+
@JsonInclude(Include.NON_EMPTY)
66+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
67+
public static class Dpo {
68+
69+
HyperParams hyperParameters;
70+
71+
public Dpo(HyperParams hyperParameters) {
72+
this.hyperParameters = hyperParameters;
73+
}
74+
75+
}
76+
77+
}

src/test/java/io/github/sashirestela/openai/domain/finetuning/FineTuningDomainTest.java

+34-3
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,43 @@ static void setup() {
3131
}
3232

3333
@Test
34-
void testFineTuningsCreate() throws IOException {
34+
void testFineTuningsCreateDpo() throws IOException {
3535
DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/finetunings_create.json");
3636
var fineTuningRequest = FineTuningRequest.builder()
3737
.trainingFile("fileId")
3838
.validationFile("fileId")
3939
.model("gpt-3.5-turbo-1106")
40-
.hyperparameters(HyperParams.builder()
40+
.suffix("suffix")
41+
.integration(Integration.builder()
42+
.type(IntegrationType.WANDB)
43+
.wandb(WandbIntegration.builder()
44+
.project("my-wandb-project")
45+
.name("ft-run-display-name")
46+
.entity("testing")
47+
.tag("first-experiment")
48+
.tag("v2")
49+
.build())
50+
.build())
51+
.seed(99)
52+
.method(MethodFineTunning.dpo(HyperParams.builder()
53+
.beta("auto")
4154
.batchSize("auto")
4255
.learningRateMultiplier("auto")
4356
.nEpochs("auto")
44-
.build())
57+
.build()))
58+
.build();
59+
var fineTuningResponse = openAI.fineTunings().create(fineTuningRequest).join();
60+
System.out.println(fineTuningResponse);
61+
assertNotNull(fineTuningResponse);
62+
}
63+
64+
@Test
65+
void testFineTuningsCreateSupervised() throws IOException {
66+
DomainTestingHelper.get().mockForObject(httpClient, "src/test/resources/finetunings_create.json");
67+
var fineTuningRequest = FineTuningRequest.builder()
68+
.trainingFile("fileId")
69+
.validationFile("fileId")
70+
.model("gpt-3.5-turbo-1106")
4571
.suffix("suffix")
4672
.integration(Integration.builder()
4773
.type(IntegrationType.WANDB)
@@ -54,6 +80,11 @@ void testFineTuningsCreate() throws IOException {
5480
.build())
5581
.build())
5682
.seed(99)
83+
.method(MethodFineTunning.supervised(HyperParams.builder()
84+
.batchSize("auto")
85+
.learningRateMultiplier("auto")
86+
.nEpochs("auto")
87+
.build()))
5788
.build();
5889
var fineTuningResponse = openAI.fineTunings().create(fineTuningRequest).join();
5990
System.out.println(fineTuningResponse);
+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{ "object": "fine_tuning.job", "id": "ftjob-35j8EBrZVsuyFFe2OD8Tkmvd", "model": "gpt-3.5-turbo-1106", "created_at": 1700533111, "finished_at": null, "fine_tuned_model": null, "organization_id": "org-4WdgDKZ75eLPEH6zqX5hFd5e", "result_files": [], "status": "validating_files", "validation_file": null, "training_file": "file-0e5BDWQYA1KsguTJRCCXqAa2", "hyperparameters": { "n_epochs": "auto", "batch_size": "auto", "learning_rate_multiplier": "auto" }, "trained_tokens": null, "error": null, "integrations":[{"type":"wandb","wandb":{"project":"my-wandb-project","name":"ft-run-display-name","tags":["first-experiment","v2"]}}], "seed": 99 }
1+
{"object":"fine_tuning.job","id":"ftjob-35j8EBrZVsuyFFe2OD8Tkmvd","model":"gpt-3.5-turbo-1106","created_at":1700533111,"finished_at":null,"fine_tuned_model":null,"organization_id":"org-4WdgDKZ75eLPEH6zqX5hFd5e","result_files":[],"status":"validating_files","validation_file":null,"training_file":"file-0e5BDWQYA1KsguTJRCCXqAa2","trained_tokens":null,"error":null,"integrations":[{"type":"wandb","wandb":{"project":"my-wandb-project","name":"ft-run-display-name","tags":["first-experiment","v2"]}}],"seed":99,"method":{"type":"supervised","supervised":{"hyperparameters":{"batch_size":"auto","learning_rate_multiplier":"auto","n_epochs":"auto"}}}}

0 commit comments

Comments
 (0)