Commit cd95cbe

Merge pull request #10 from ittia-research/dev
change all LLM calling to DSPy, increase citation token limit
2 parents 44b7391 + 4e9326e

17 files changed, +171 -175 lines

.env

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,7 @@
 EMBEDDING_API_KEY=ollama:abc
 EMBEDDING_MODEL_DEPLOY=api
 EMBEDDING_MODEL_NAME=jina/jina-embeddings-v2-base-en
+INDEX_CHUNK_SIZES=[2048, 512, 128]
 LLM_MODEL_NAME=google/gemma-2-27b-it
 OLLAMA_BASE_URL=http://ollama:11434
 OPENAI_API_KEY=sk-proj-aaaaaaaaaaaaaaaaa
@@ -10,5 +11,4 @@ RERANK_MODEL_DEPLOY=local
 RERANK_MODEL_NAME=BAAI/bge-reranker-v2-m3
 RERANK_BASE_URL=http://xinference:9997/v1
 SEARCH_BASE_URL=https://s.jina.ai
-THREAD_BUILD_INDEX=12
-RAG_CHUNK_SIZES=[4096, 1024, 256]
+THREAD_BUILD_INDEX=12
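
Note: RAG_CHUNK_SIZES is renamed to INDEX_CHUNK_SIZES; the consumer is updated in src/modules/retrieve.py below. For reference, a minimal sketch of how such a bracketed list can be read from the environment, assuming the settings module parses it as JSON (the actual parsing code is not part of this diff):

import json, os

# Assumption: settings parses the bracketed list with json.loads.
INDEX_CHUNK_SIZES = json.loads(os.environ.get("INDEX_CHUNK_SIZES", "[2048, 512, 128]"))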

README.md

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ DSPy:
 ### Reports
 - [ ] AI-generated misinformation
 ### Factcheck
+- https://www.snopes.com
 - https://www.bmi.bund.de/SharedDocs/schwerpunkte/EN/disinformation/examples-of-russian-disinformation-and-the-facts.html
 ### Resources
 #### Inference

src/llm.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

src/main.py

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@
 from fastapi.responses import Response, JSONResponse, HTMLResponse, PlainTextResponse, FileResponse
 import logging
 
-import llm, utils, pipeline
+import utils, pipeline
 
 logging.basicConfig(
     level=logging.INFO,
@@ -18,7 +18,7 @@
 async def fact_check(input):
     status = 500
     logger.info(f"Fact checking: {input}")
-    statements = await run_in_threadpool(llm.get_statements, input)
+    statements = await run_in_threadpool(pipeline.get_statements, input)
     logger.info(f"statements: {statements}")
     if not statements:
         raise HTTPException(status_code=status, detail="No statements found")
@@ -29,11 +29,11 @@ async def fact_check(input):
         if not statement:
             continue
         logger.info(f"statement: {statement}")
-        keywords = await run_in_threadpool(llm.get_search_keywords, statement)
-        if not keywords:
+        query = await run_in_threadpool(pipeline.get_search_query, statement)
+        if not query:
             continue
-        logger.info(f"keywords: {keywords}")
-        search = await utils.search(keywords)
+        logger.info(f"search query: {query}")
+        search = await utils.search(query)
         if not search:
             fail_search = True
             continue

src/modules/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import dspy
+
+from settings import settings
+
+# set DSPy default language model
+llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n')
+dspy.settings.configure(lm=llm)
+
+# LM with higher token limits
+llm_long = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=500, stop='\n\n')
+
+from .citation import Citation
+from .ollama_embedding import OllamaEmbedding
+from .retrieve import LlamaIndexRM
+from .search_query import SearchQuery
+from .statements import Statements
+from .verdict import Verdict
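
Importing this package configures a 200-token default LM and exposes llm_long (500 tokens), covering the "increase citation token limit" part of the commit message. A hedged sketch of how a caller might switch to the higher-limit LM for a single call via DSPy's context manager; the actual call site (src/pipeline/verdict_citation.py) is not shown in this diff:

import dspy
from modules import Citation, llm_long

# Hypothetical usage: run only the citation step with the 500-token LM,
# leaving the 200-token default configured for everything else.
with dspy.context(lm=llm_long):
    pred = Citation()(statement="...", context="...", verdict="...")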

src/modules/citation.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import dspy
+
+# TODO: citation needs higher token limits
+class GenerateCitedParagraph(dspy.Signature):
+    """Generate a paragraph with citations."""
+    context = dspy.InputField(desc="may contain relevant facts")
+    statement = dspy.InputField()
+    verdict = dspy.InputField()
+    paragraph = dspy.OutputField(desc="includes citations")
+
+"""Generate citation from context and verdict"""
+class Citation(dspy.Module):
+    def __init__(self):
+        super().__init__()
+        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)
+
+    def forward(self, statement, context, verdict):
+        citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict)
+        pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context)
+        return pred
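
For reference, a minimal usage sketch matching the forward() signature above; the inputs are illustrative, and the real pipeline passes retrieved passages as context:

from modules import Citation

pred = Citation()(
    statement="The Eiffel Tower is located in Berlin.",
    context=["The Eiffel Tower was built in Paris in 1889."],  # illustrative passage
    verdict="false",
)
print(pred.citation)  # paragraph restating the verdict, with citations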
src/ollama_embedding.py renamed to src/modules/ollama_embedding.py

File renamed without changes.

src/retrieve.py renamed to src/modules/retrieve.py

Lines changed: 2 additions & 4 deletions
@@ -8,11 +8,9 @@
 
 from llama_index.core import (
     Document,
-    ServiceContext,
     Settings,
     StorageContext,
     VectorStoreIndex,
-    load_index_from_storage,
 )
 from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes
 from llama_index.core.retrievers import AutoMergingRetriever
@@ -29,7 +27,7 @@
 jinaai_rerank.API_URL = settings.RERANK_BASE_URL + "/rerank" # switch to on-premise
 
 # todo: high latency between client and the ollama embedding server will slow down embedding a lot
-from ollama_embedding import OllamaEmbedding
+from . import OllamaEmbedding
 
 # todo: improve embedding performance
 if settings.EMBEDDING_MODEL_DEPLOY == "local":
@@ -132,7 +130,7 @@ def build_index(self, docs):
         if docs:
             self.index, self.storage_context = self.build_automerging_index(
                 docs,
-                chunk_sizes=settings.RAG_CHUNK_SIZES,
+                chunk_sizes=settings.INDEX_CHUNK_SIZES,
             ) # TODO: try to retrieve directly
 
     def retrieve(self, query):
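
A minimal sketch of the retriever after the rename, based only on the constructor and retrieve() signature visible in this diff; docs is assumed to come from utils.search_json_to_docs as in src/pipeline/common.py:

from modules import LlamaIndexRM

docs = [...]  # llama_index Documents, e.g. from utils.search_json_to_docs(search_json)
rm = LlamaIndexRM(docs=docs)  # builds the auto-merging index using settings.INDEX_CHUNK_SIZES
passages = rm.retrieve("a claim to verify")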

src/modules/search_query.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import dspy
+import logging
+
+"""Notes: LLM will choose a direction based on known facts"""
+class GenerateSearchEngineQuery(dspy.Signature):
+    """Write a search engine query that will help retrieve info related to the statement."""
+    statement = dspy.InputField()
+    query = dspy.OutputField()
+
+class SearchQuery(dspy.Module):
+    def __init__(self):
+        super().__init__()
+        self.generate_query = dspy.ChainOfThought(GenerateSearchEngineQuery)
+
+    def forward(self, statement):
+        query = self.generate_query(statement=statement)
+        logging.info(f"DSPy CoT search query: {query}")
+        return query.query
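
Usage sketch with an illustrative statement; forward() returns query.query, so the module yields a plain string ready for utils.search:

from modules import SearchQuery

query = SearchQuery()(statement="Honey never spoils if stored sealed.")
print(query)  # a single search-engine query string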

src/modules/statements.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import dspy
+import logging
+from pydantic import BaseModel, Field
+from typing import List
+
+# references: https://github.com/weaviate/recipes/blob/main/integrations/llm-frameworks/dspy/4.Structured-Outputs-with-DSPy.ipynb
+class Output(BaseModel):
+    statements: List = Field(description="A list of key statements")
+
+# TODO: test consistency especially when content contains false claims
+class GenerateStatements(dspy.Signature):
+    """Extract the original statements from given content without fact check."""
+    content: str = dspy.InputField(desc="The content to summarize")
+    output: Output = dspy.OutputField()
+
+class Statements(dspy.Module):
+    def __init__(self):
+        super().__init__()
+        self.generate_statements = dspy.TypedChainOfThought(GenerateStatements, max_retries=6)
+
+    def forward(self, content):
+        statements = self.generate_statements(content=content)
+        logging.info(f"DSPy CoT statements: {statements}")
+        return statements.output.statements
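
TypedChainOfThought validates the LM output against the pydantic Output model, retrying up to 6 times on schema violations. Usage sketch with an illustrative input:

from modules import Statements

content = "The Nile is the longest river. It flows through 11 countries."
statements = Statements()(content=content)
# e.g. ["The Nile is the longest river.", "The Nile flows through 11 countries."]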
src/modules/verdict.py

Lines changed: 2 additions & 54 deletions
@@ -1,12 +1,6 @@
 import dspy
 from dsp.utils import deduplicate
 
-from retrieve import LlamaIndexRM
-from settings import settings
-
-llm = dspy.OpenAI(model=settings.LLM_MODEL_NAME, api_base=f"{settings.OPENAI_BASE_URL}/", max_tokens=200, stop='\n\n')
-dspy.settings.configure(lm=llm)
-
 class CheckStatementFaithfulness(dspy.Signature):
     """Verify that the statement is based on the provided context."""
     context = dspy.InputField(desc="facts here are assumed to be true")
@@ -19,14 +13,6 @@ class GenerateSearchQuery(dspy.Signature):
     statement = dspy.InputField()
     query = dspy.OutputField()
 
-# TODO: citation needs higher token limits
-class GenerateCitedParagraph(dspy.Signature):
-    """Generate a paragraph with citations."""
-    context = dspy.InputField(desc="may contain relevant facts")
-    statement = dspy.InputField()
-    verdict = dspy.InputField()
-    paragraph = dspy.OutputField(desc="includes citations")
-
 """
 SimplifiedBaleen module
 Avoid unnecessary content in module because the MIPROv2 optimizer will analyze modules.
@@ -39,7 +25,7 @@ class GenerateCitedParagraph(dspy.Signature):
 - remove some contexts in case tokens reach the max
 - does a different InputField name (other than answer) stay compatible with dspy evaluate?
 """
-class ContextVerdict(dspy.Module):
+class Verdict(dspy.Module):
     def __init__(self, retrieve, passages_per_hop=3, max_hops=3):
         super().__init__()
         # self.generate_query = dspy.ChainOfThought(GenerateSearchQuery) # IMPORTANT: solves error `list index out of range`
@@ -59,42 +45,4 @@ def forward(self, statement):
         verdict = self.generate_verdict(context=context, statement=statement)
         pred = dspy.Prediction(answer=verdict.verdict, rationale=verdict.rationale, context=context)
         return pred
-
-"""Generate citation from context and verdict"""
-class Citation(dspy.Module):
-    def __init__(self):
-        super().__init__()
-        self.generate_cited_paragraph = dspy.ChainOfThought(GenerateCitedParagraph)
-
-    def forward(self, statement, context, verdict):
-        citation = self.generate_cited_paragraph(context=context, statement=statement, verdict=verdict)
-        pred = dspy.Prediction(verdict=verdict, citation=citation.paragraph, context=context)
-        return pred
-
-"""
-Get both verdict and citation.
-
-Args:
-    retrieve: dspy.Retrieve
-"""
-class VerdictCitation():
-    def __init__(
-        self,
-        docs,
-    ):
-        self.retrieve = LlamaIndexRM(docs=docs)
-
-        # loading compiled ContextVerdict
-        self.context_verdict = ContextVerdict(retrieve=self.retrieve)
-        self.context_verdict.load("./optimizers/verdict_MIPROv2.json")
-
-    def get(self, statement):
-        rep = self.context_verdict(statement)
-        context = rep.context
-        verdict = rep.answer
-
-        rep = Citation()(statement=statement, context=context, verdict=verdict)
-        citation = rep.citation
-
-        return verdict, citation
-
+
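
Verdict (formerly ContextVerdict) keeps the multi-hop retrieve-and-judge loop; the Citation and VerdictCitation pieces move to src/modules/citation.py and src/pipeline/verdict_citation.py. A sketch of direct use, assuming a retriever built as in src/modules/retrieve.py:

from modules import LlamaIndexRM, Verdict

retrieve = LlamaIndexRM(docs=docs)  # docs assumed prepared elsewhere
verdict = Verdict(retrieve=retrieve)
pred = verdict("CO2 levels are lower today than in 1900.")
print(pred.answer)   # the verdict
print(pred.context)  # accumulated multi-hop passages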

src/pipeline.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/pipeline/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .common import get_search_query, get_statements, get_verdict
+from .verdict_citation import VerdictCitation

src/pipeline/common.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import logging
+import utils
+from modules import SearchQuery, Statements
+from .verdict_citation import VerdictCitation
+
+def get_statements(content):
+    """Get list of statements from a text string"""
+    try:
+        statements = Statements()(content=content)
+    except Exception as e:
+        logging.error(f"Getting statements failed: {e}")
+        statements = []
+
+    return statements
+
+def get_search_query(statement):
+    """Get search query from one statement"""
+
+    try:
+        query = SearchQuery()(statement=statement)
+    except Exception as e:
+        logging.error(f"Getting search query from statement '{statement}' failed: {e}")
+        query = ""
+
+    return query
+
+def get_verdict(search_json, statement):
+    docs = utils.search_json_to_docs(search_json)
+    rep = VerdictCitation(docs=docs).get(statement=statement)
+
+    return {
+        "verdict": rep[0],
+        "citation": rep[1],
+        "statement": statement,
+    }
+
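
Taken together, the pipeline package now fronts all DSPy calls. A hedged end-to-end sketch mirroring the flow in src/main.py (utils.search is async there, hence the wrapper):

import pipeline, utils

async def check(content: str):
    results = []
    for statement in pipeline.get_statements(content):
        query = pipeline.get_search_query(statement)
        if not query:
            continue
        search_json = await utils.search(query)
        results.append(pipeline.get_verdict(search_json, statement))
    return results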
