From 92cd2feaf62bf17f4b14a922f21ac0bb65e72cb9 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 11 Sep 2024 15:11:54 +0300 Subject: [PATCH 1/2] Exclude system prompt from saving of chat history --- Runtime/LLMCharacter.cs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index e05012f2..22893bff 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -239,17 +239,17 @@ protected async Task LoadHistory() } } - protected string GetSavePath(string filename) + public string GetSavePath(string filename) { return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/'); } - protected string GetJsonSavePath(string filename) + public string GetJsonSavePath(string filename) { return GetSavePath(filename + ".json"); } - protected string GetCacheSavePath(string filename) + public string GetCacheSavePath(string filename) { // this is saved already in the Application.persistentDataPath folder return GetSavePath(filename + ".cache"); @@ -648,7 +648,7 @@ public async Task Save(string filename) string filepath = GetJsonSavePath(filename); string dirname = Path.GetDirectoryName(filepath); if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname); - string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat }); + string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) }); File.WriteAllText(filepath, json); string cachepath = GetCacheSavePath(filename); @@ -671,7 +671,9 @@ public async Task Load(string filename) return null; } string json = File.ReadAllText(filepath); - chat = JsonUtility.FromJson(json).chat; + List chatHistory = JsonUtility.FromJson(json).chat; + InitPrompt(true); + chat.AddRange(chatHistory); LLMUnitySetup.Log($"Loaded {filepath}"); string cachepath = GetCacheSavePath(filename); From 8b06785f5a152d475eb9b9eaf7a6432eb3bc3c28 Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Wed, 11 Sep 2024 15:13:28 +0300 Subject: [PATCH 2/2] add LLMCharacter save unit test --- Tests/Runtime/TestLLM.cs | 51 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index edf0907b..9b2bba16 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -246,9 +246,9 @@ public virtual async Task Tests() public void TestInitParameters(int nkeep, int chats) { - Assert.That(llmCharacter.nKeep == nkeep); + Assert.AreEqual(llmCharacter.nKeep, nkeep); Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0); - Assert.That(llmCharacter.chat.Count == chats); + Assert.AreEqual(llmCharacter.chat.Count, chats); } public void TestTokens(List tokens) @@ -410,7 +410,7 @@ public override async Task Tests() public class TestLLM_Double : TestLLM { LLM llm1; - LLMCharacter lLMCharacter1; + LLMCharacter llmCharacter1; public override async Task Init() { @@ -421,8 +421,51 @@ public override async Task Init() llm = CreateLLM(); llmCharacter = CreateLLMCharacter(); llm1 = CreateLLM(); - lLMCharacter1 = CreateLLMCharacter(); + llmCharacter1 = CreateLLMCharacter(); gameObject.SetActive(true); } } + + public class TestLLMCharacter_Save : TestLLM + { + string saveName = "TestLLMCharacter_Save"; + + public override LLMCharacter CreateLLMCharacter() + { + LLMCharacter llmCharacter = base.CreateLLMCharacter(); + llmCharacter.save = saveName; + llmCharacter.saveCache = true; + return llmCharacter; + } + + public override async Task Tests() + { + await base.Tests(); + TestSave(); + } + + public void TestSave() + { + string jsonPath = llmCharacter.GetJsonSavePath(saveName); + string cachePath = llmCharacter.GetCacheSavePath(saveName); + Assert.That(File.Exists(jsonPath)); + Assert.That(File.Exists(cachePath)); + string json = File.ReadAllText(jsonPath); + File.Delete(jsonPath); + File.Delete(cachePath); + + List chatHistory = JsonUtility.FromJson(json).chat; + Assert.AreEqual(chatHistory.Count, 2); + Assert.AreEqual(chatHistory[0].role, llmCharacter.playerName); + Assert.AreEqual(chatHistory[0].content, "hi"); + Assert.AreEqual(chatHistory[1].role, llmCharacter.AIName); + + Assert.AreEqual(llmCharacter.chat.Count, chatHistory.Count + 1); + for (int i = 0; i < chatHistory.Count; i++) + { + Assert.AreEqual(chatHistory[i].role, llmCharacter.chat[i + 1].role); + Assert.AreEqual(chatHistory[i].content, llmCharacter.chat[i + 1].content); + } + } + } }