Commit 34aca39

Merge pull request #56 from eli5-org/llm-mlx

Support open source LLMs with mlx-lm

2 parents 803d8b2 + 8f883c1

File tree

5 files changed: +353 −20 lines

docs/source/_notebooks/explain_llm_logprobs.rst

Lines changed: 122 additions & 4 deletions
@@ -11,6 +11,9 @@ about its predictions:
 
 LLM token probabilities visualized with eli5.explain_prediction
 
+1. OpenAI models
+----------------
+
 To follow this tutorial you need the ``openai`` library installed and
 working.
 
@@ -64,10 +67,10 @@ properties from a free-form product description:
 json
 {
   "materials": ["metal"],
-  "type": "table lamp",
+  "type": "task lighting",
   "color": "silky matte grey",
   "price": 150.00,
-  "summary": "Stay is a flexible and elegant table lamp designed by Maria Berntsen."
+  "summary": "Stay table lamp with adjustable arm and head for optimal task lighting."
 }
 
@@ -311,8 +314,8 @@ We can obtain the original prediction from the explanation object via
 ``explanation.targets[0].target.message.content`` to get the prediction
 text.
 
-Limitations
------------
+2. Limitations
+--------------
 
 Even though above the model confidence matched our expectations, it’s
 not always the case. For example, if we use “Chain of Thought”
@@ -531,6 +534,121 @@ temperatures:
 
 
 
+3. Open Source and other models
+-------------------------------
+
+If an API endpoint can provide ``logprobs`` in the right format, then it
+should work. However, few APIs or libraries provide it, even for open
+source models. One library which is known to work is ``mlx_lm`` (macOS
+only), e.g. if you start the server like this:
+
+::
+
+   mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
+
+Then you can explain predictions with a custom client:
+
+.. code:: ipython3
+
+    client_custom = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="dummy")
+    eli5.explain_prediction(
+        client_custom,
+        prompt + ' Price should never be zero.',
+        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+    )
+
+.. raw:: html
+
+    <style>
+    table.eli5-weights tr:hover {
+        filter: brightness(85%);
+    }
+    </style>
+
+    <p style="margin-bottom: 2.5em; margin-top:0; white-space: pre-wrap;"><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">{
+    </span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969"> </span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> &quot;materials&quot;: [&quot;</span><span style="background-color: hsl(94.45469472360817, 100.00%, 50.00%)" title="0.925">sil</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ky matte grey metal</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984">&quot;],</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">
+    &quot;type&quot;:</span><span style="background-color: hsl(67.65625435391846, 100.00%, 50.00%)" title="0.616"> &quot;</span><span style="background-color: hsl(54.98115667185111, 100.00%, 50.00%)" title="0.423">Not</span><span style="background-color: hsl(99.15673841086969, 100.00%, 50.00%)" title="0.954"> specified</span><span style="background-color: hsl(60.184714790030306, 100.00%, 50.00%)" title="0.503"> in</span><span style="background-color: hsl(89.21337460784397, 100.00%, 50.00%)" title="0.882"> the</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767"> description</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">&quot;,
+    &quot;color&quot;: &quot;</span><span style="background-color: hsl(75.83477075768666, 100.00%, 50.00%)" title="0.732">Not</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> specified in the description&quot;,
+    &quot;price&quot;:</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984"> </span><span style="background-color: hsl(66.38465137200743, 100.00%, 50.00%)" title="0.597">9</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">9.99,</span><span style="background-color: hsl(70.37163543696023, 100.00%, 50.00%)" title="0.656">
+    </span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> &quot;summary&quot;: &quot;</span><span style="background-color: hsl(87.74310691905805, 100.00%, 50.00%)" title="0.869">St</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ay</span><span style="background-color: hsl(96.62538179271525, 100.00%, 50.00%)" title="0.939"> is</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> a</span><span style="background-color: hsl(50.86740763861728, 100.00%, 50.00%)" title="0.362"> flexible</span><span style="background-color: hsl(74.98611970505678, 100.00%, 50.00%)" title="0.720"> and</span><span style="background-color: hsl(53.693017832784676, 100.00%, 50.00%)" title="0.404"> beautiful</span><span style="background-color: hsl(86.3702071513559, 100.00%, 50.00%)" title="0.855"> Dan</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ish</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767">-</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">designed</span><span style="background-color: hsl(83.859705332877, 100.00%, 50.00%)" title="0.829"> table</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> lamp</span><span style="background-color: hsl(83.859705332877, 100.00%, 50.00%)" title="0.829"> with</span><span style="background-color: hsl(63.98782501663244, 100.00%, 50.00%)" title="0.561"> a</span><span style="background-color: hsl(57.717412015697704, 100.00%, 50.00%)" title="0.465"> discre</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">et</span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969"> switch</span><span style="background-color: hsl(73.36326933304912, 100.00%, 50.00%)" title="0.698"> and</span><span style="background-color: hsl(52.049284214496396, 100.00%, 50.00%)" title="0.380"> adjust</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">able</span><span style="background-color: hsl(68.98323009885026, 100.00%, 50.00%)" title="0.636"> arm</span><span style="background-color: hsl(90.80119410884282, 100.00%, 50.00%)" title="0.896"> and</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> head</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767">,</span><span style="background-color: hsl(65.7679832651431, 100.00%, 50.00%)" title="0.588"> ideal</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> for</span><span style="background-color: hsl(69.66933255201431, 100.00%, 50.00%)" title="0.646"> office</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984"> task</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> lighting</span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969">.&quot;</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">
+    }</span></p>
eli5/llm/explain_prediction.py

Lines changed: 54 additions & 6 deletions
@@ -1,8 +1,10 @@
 import math
-from typing import Union
+import warnings
+from typing import Optional, Union
 
 import openai
-from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletion
+from openai.types.chat.chat_completion import (
+    ChatCompletion, ChatCompletionTokenLogprob, ChoiceLogprobs)
 
 from eli5.base import Explanation, TargetExplanation, WeightedSpans, DocWeightedSpans
 from eli5.explain import explain_prediction
@@ -49,15 +51,15 @@ def explain_prediction_openai_logprobs(logprobs: ChoiceLogprobs, doc=None):
 
 @explain_prediction.register(ChatCompletion)
 def explain_prediction_openai_completion(
-        chat_completion: ChatCompletion, doc=None):
+        completion: ChatCompletion, doc=None):
     """ Creates an explanation of the ChatCompletion's logprobs
     highlighting them proportionally to the log probability.
     More likely tokens are highlighted in green,
     while unlikely tokens are highlighted in red.
     ``doc`` argument is ignored.
     """
     targets = []
-    for choice in chat_completion.choices:
+    for choice in completion.choices:
         if choice.logprobs is None:
             raise ValueError('Predictions must be obtained with logprobs enabled')
         target, = explain_prediction_openai_logprobs(choice.logprobs).targets
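The probabilities shown in the explanation's tooltips follow directly from the logprobs above: a token's displayed probability is just `exp(logprob)`. A minimal sketch of that mapping (an illustration, not eli5's actual rendering code):

```python
import math

# Each OpenAI-style token logprob is a natural log of the token's
# probability, so exponentiating recovers the probability shown in
# the highlight tooltips.
def token_probability(logprob: float) -> float:
    return math.exp(logprob)

token_logprobs = [math.log(0.9), math.log(0.2), math.log(0.4)]
probs = [round(token_probability(lp), 2) for lp in token_logprobs]
print(probs)  # -> [0.9, 0.2, 0.4]
```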
@@ -92,8 +94,54 @@ def explain_prediction_openai_client(
     else:
         messages = doc
     kwargs['logprobs'] = True
-    chat_completion = client.chat.completions.create(
+    completion = client.chat.completions.create(
         messages=messages,  # type: ignore
         model=model,
         **kwargs)
-    return explain_prediction_openai_completion(chat_completion)
+    for choice in completion.choices:
+        _recover_logprobs(choice.logprobs, model)
+        if choice.logprobs is None:
+            raise ValueError('logprobs not found, likely API does not support them')
+        if choice.logprobs.content is None:
+            raise ValueError(f'logprobs.content is empty: {choice.logprobs}')
+    return explain_prediction_openai_completion(completion)
+
+
+def _recover_logprobs(logprobs: Optional[ChoiceLogprobs], model: str):
+    """ Some servers don't populate logprobs.content, try to recover it.
+    """
+    if logprobs is None:
+        return
+    if logprobs.content is not None:
+        return
+    if not (
+            getattr(logprobs, 'token_logprobs', None) and
+            getattr(logprobs, 'tokens', None)):
+        return
+    assert hasattr(logprobs, 'token_logprobs')  # for mypy
+    assert hasattr(logprobs, 'tokens')  # for mypy
+    try:
+        import tokenizers
+    except ImportError:
+        warnings.warn('tokenizers library required to recover logprobs.content')
+        return
+    try:
+        tokenizer = tokenizers.Tokenizer.from_pretrained(model)
+    except Exception:
+        warnings.warn(f'could not load tokenizer for {model} with tokenizers library')
+        return
+    assert len(logprobs.token_logprobs) == len(logprobs.tokens)
+    # get tokens as strings with spaces, is there any better way?
+    text = tokenizer.decode(logprobs.tokens)
+    encoded = tokenizer.encode(text, add_special_tokens=False)
+    text_tokens = [text[start:end] for (start, end) in encoded.offsets]
+    logprobs.content = []
+    for logprob, token in zip(logprobs.token_logprobs, text_tokens):
+        logprobs.content.append(
+            ChatCompletionTokenLogprob(
+                token=token,
+                bytes=list(map(int, token.encode('utf8'))),
+                logprob=logprob,
+                top_logprobs=[],  # we could recover that too
+            )
+        )
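The recovery in `_recover_logprobs` round-trips through the tokenizer: decode the raw token ids to text, re-encode to obtain character offsets, then slice the text so each recovered token string carries its surrounding whitespace. A toy stand-in for that slicing idea, using a regex word match in place of the real `tokenizers` offsets (an illustration under that simplification, not the library's behavior):

```python
import re

# Toy stand-in for tokenizers' Encoding.offsets: (start, end) character
# spans of each "token" in the decoded text.
def toy_offsets(text: str) -> list[tuple[int, int]]:
    return [(m.start(), m.end()) for m in re.finditer(r'\S+', text)]

# Slice the text so each token absorbs the whitespace before it; this
# extends every span back to the previous token's end so no characters
# of the original text are lost.
def recover_tokens(text: str) -> list[str]:
    tokens, prev_end = [], 0
    for _start, end in toy_offsets(text):
        tokens.append(text[prev_end:end])
        prev_end = end
    return tokens

print(recover_tokens("Hello world world"))  # -> ['Hello', ' world', ' world']
```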

notebooks/explain_llm_logprobs.ipynb

Lines changed: 131 additions & 3 deletions
Large diffs are not rendered by default.

tests/test_llm_explain_prediction.py

Lines changed: 45 additions & 7 deletions
@@ -3,6 +3,7 @@
 from unittest.mock import Mock
 
 pytest.importorskip('openai')
+pytest.importorskip('transformers')
 from openai.types.chat.chat_completion import (
     ChoiceLogprobs,
     ChatCompletion,
@@ -11,6 +12,7 @@
     Choice,
 )
 from openai import Client
+import transformers
 
 import eli5
 from eli5.base import Explanation
@@ -40,20 +42,28 @@ def example_logprobs():
 
 @pytest.fixture
 def example_completion(example_logprobs):
+    return create_completion(
+        model='gpt-4o-2024-08-06',
+        logprobs=example_logprobs,
+        message=ChatCompletionMessage(
+            content=''.join(x.token for x in example_logprobs.content),
+            role='assistant',
+        ),
+    )
+
+
+def create_completion(model, logprobs, message):
     return ChatCompletion(
         id='chatcmpl-x',
         created=1743590849,
-        model='gpt-4o-2024-08-06',
+        model=model,
         object='chat.completion',
         choices=[
             Choice(
-                logprobs=example_logprobs,
+                logprobs=logprobs,
                 finish_reason='stop',
                 index=0,
-                message=ChatCompletionMessage(
-                    content=''.join(x.token for x in example_logprobs.content),
-                    role='assistant',
-                ),
+                message=message,
             )
         ],
     )
@@ -100,7 +110,35 @@ def __init__(self, chat_return_value):
 def test_explain_prediction_openai_client(monkeypatch, example_completion):
     client = MockClient(example_completion)
 
-    explanation = eli5.explain_prediction(client, doc="Hello world", model="gpt-4o")
+    explanation = eli5.explain_prediction(client, doc="Hello world world", model="gpt-4o")
+    _assert_explanation_structure_and_html(explanation)
+
+    client.chat.completions.create.assert_called_once()
+
+
+def test_explain_prediction_openai_client_mlx(monkeypatch):
+    model = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model)
+
+    text = 'Hello world world'
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+    assert len(tokens) == 3
+    logprobs = ChoiceLogprobs(
+        token_logprobs=[
+            math.log(0.9),
+            math.log(0.2),
+            math.log(0.4),
+        ],
+        tokens=tokens,
+    )
+    completion = create_completion(
+        model=model,
+        logprobs=logprobs,
+        message=ChatCompletionMessage(content=text, role='assistant'),
+    )
+    client = MockClient(completion)
+
+    explanation = eli5.explain_prediction(client, doc=text, model=model)
     _assert_explanation_structure_and_html(explanation)
 
     client.chat.completions.create.assert_called_once()
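The mlx test above simulates the mlx_lm-style logprobs shape: parallel `tokens` and `token_logprobs` arrays instead of OpenAI's `logprobs.content` list of records. A minimal sketch of converting one into the other, assuming the token strings are already known (which sidesteps the tokenizer round-trip that `_recover_logprobs` performs):

```python
import math

# mlx_lm-style payload: parallel arrays. Token strings are assumed known
# here for illustration; in _recover_logprobs they must first be
# reconstructed from token ids via the tokenizer.
tokens = ['Hello', ' world', ' world']
token_logprobs = [math.log(0.9), math.log(0.2), math.log(0.4)]

# OpenAI-style logprobs.content: one record per token, mirroring the
# fields of ChatCompletionTokenLogprob.
content = [
    {
        'token': tok,
        'bytes': list(tok.encode('utf8')),
        'logprob': lp,
        'top_logprobs': [],
    }
    for tok, lp in zip(tokens, token_logprobs)
]
print(content[1]['token'])  # -> ' world'
```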

tox.ini

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ deps=
     pandas
     sklearn-crfsuite
     openai
+    tokenizers
 commands=
     pip install -e .
     py.test --doctest-modules \

0 commit comments
