126 changes: 122 additions & 4 deletions docs/source/_notebooks/explain_llm_logprobs.rst
@@ -11,6 +11,9 @@ about its predictions:

LLM token probabilities visualized with eli5.explain_prediction

1. OpenAI models
----------------

To follow this tutorial you need the ``openai`` library installed and
working.
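
For reference, a minimal setup might look like this (a sketch only; it
assumes the ``OPENAI_API_KEY`` environment variable is set):

.. code:: ipython3

    import openai
    import eli5

    client = openai.OpenAI()  # picks up OPENAI_API_KEY from the environment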

@@ -64,10 +67,10 @@ properties from a free-form product description:
json
   {
       "materials": ["metal"],
-      "type": "table lamp",
+      "type": "task lighting",
       "color": "silky matte grey",
       "price": 150.00,
-      "summary": "Stay is a flexible and elegant table lamp designed by Maria Berntsen."
+      "summary": "Stay table lamp with adjustable arm and head for optimal task lighting."
   }


@@ -311,8 +314,8 @@ We can obtain the original prediction from the explanation object via
``explanation.targets[0].target.message.content`` to get the prediction
text.
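
For example, the following quick sketch prints it (reusing the
``explanation`` object from the call above):

.. code:: ipython3

    # The raw prediction text that the explanation was built from.
    print(explanation.targets[0].target.message.content)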

-Limitations
------------
+2. Limitations
+--------------

Even though the model confidence above matched our expectations, that is
not always the case. For example, if we use “Chain of Thought”
@@ -531,6 +534,121 @@ temperatures:


3. Open Source and other models
-------------------------------

If an API endpoint can provide ``logprobs`` in the right format, then it
should work. However, few APIs or libraries provide it, even for open
source models. One library which is known to work is ``mlx_lm`` (macOS
only). For example, if you start the server like this:

::

    mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit

Then you can explain predictions with a custom client:

.. code:: ipython3

    client_custom = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="dummy")
    eli5.explain_prediction(
        client_custom,
        prompt + ' Price should never be zero.',
        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
    )




.. raw:: html


<style>
table.eli5-weights tr:hover {
filter: brightness(85%);
}
</style>


<p style="margin-bottom: 2.5em; margin-top:0; white-space: pre-wrap;"><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">{
</span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969"> </span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> &quot;materials&quot;: [&quot;</span><span style="background-color: hsl(94.45469472360817, 100.00%, 50.00%)" title="0.925">sil</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ky matte grey metal</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984">&quot;],</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">
&quot;type&quot;:</span><span style="background-color: hsl(67.65625435391846, 100.00%, 50.00%)" title="0.616"> &quot;</span><span style="background-color: hsl(54.98115667185111, 100.00%, 50.00%)" title="0.423">Not</span><span style="background-color: hsl(99.15673841086969, 100.00%, 50.00%)" title="0.954"> specified</span><span style="background-color: hsl(60.184714790030306, 100.00%, 50.00%)" title="0.503"> in</span><span style="background-color: hsl(89.21337460784397, 100.00%, 50.00%)" title="0.882"> the</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767"> description</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">&quot;,
&quot;color&quot;: &quot;</span><span style="background-color: hsl(75.83477075768666, 100.00%, 50.00%)" title="0.732">Not</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> specified in the description&quot;,
&quot;price&quot;:</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984"> </span><span style="background-color: hsl(66.38465137200743, 100.00%, 50.00%)" title="0.597">9</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">9.99,</span><span style="background-color: hsl(70.37163543696023, 100.00%, 50.00%)" title="0.656">
</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> &quot;summary&quot;: &quot;</span><span style="background-color: hsl(87.74310691905805, 100.00%, 50.00%)" title="0.869">St</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ay</span><span style="background-color: hsl(96.62538179271525, 100.00%, 50.00%)" title="0.939"> is</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> a</span><span style="background-color: hsl(50.86740763861728, 100.00%, 50.00%)" title="0.362"> flexible</span><span style="background-color: hsl(74.98611970505678, 100.00%, 50.00%)" title="0.720"> and</span><span style="background-color: hsl(53.693017832784676, 100.00%, 50.00%)" title="0.404"> beautiful</span><span style="background-color: hsl(86.3702071513559, 100.00%, 50.00%)" title="0.855"> Dan</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">ish</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767">-</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">designed</span><span style="background-color: hsl(83.859705332877, 100.00%, 50.00%)" title="0.829"> table</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> lamp</span><span style="background-color: hsl(83.859705332877, 100.00%, 50.00%)" title="0.829"> with</span><span style="background-color: hsl(63.98782501663244, 100.00%, 50.00%)" title="0.561"> a</span><span style="background-color: hsl(57.717412015697704, 100.00%, 50.00%)" title="0.465"> discre</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">et</span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969"> switch</span><span style="background-color: hsl(73.36326933304912, 100.00%, 50.00%)" title="0.698"> and</span><span style="background-color: hsl(52.049284214496396, 100.00%, 50.00%)" title="0.380"> adjust</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">able</span><span style="background-color: hsl(68.98323009885026, 100.00%, 50.00%)" title="0.636"> arm</span><span style="background-color: hsl(90.80119410884282, 100.00%, 50.00%)" title="0.896"> and</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> head</span><span style="background-color: hsl(78.55493818915068, 100.00%, 50.00%)" title="0.767">,</span><span style="background-color: hsl(65.7679832651431, 100.00%, 50.00%)" title="0.588"> ideal</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> for</span><span style="background-color: hsl(69.66933255201431, 100.00%, 50.00%)" title="0.646"> office</span><span style="background-color: hsl(106.55663140845363, 100.00%, 50.00%)" title="0.984"> task</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000"> lighting</span><span style="background-color: hsl(102.26933456703064, 100.00%, 50.00%)" title="0.969">.&quot;</span><span style="background-color: hsl(120.0, 100.00%, 50.00%)" title="1.000">
}</span></p>
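
The token weights can also be inspected programmatically. Here is a
minimal sketch (it assumes the result of the ``eli5.explain_prediction``
call above is assigned to ``explanation``, and that spans follow the
``eli5.base.DocWeightedSpans`` structure):

.. code:: ipython3

    # Each span is (token, [(start, end)], weight), where the weight is
    # derived from the token's log probability.
    spans = explanation.targets[0].weighted_spans.docs_weighted_spans[0].spans
    for token, _positions, weight in spans:
        print(f"{token!r}: {weight:.3f}")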


60 changes: 54 additions & 6 deletions eli5/llm/explain_prediction.py
@@ -1,8 +1,10 @@
 import math
-from typing import Union
+import warnings
+from typing import Optional, Union
 
 import openai
-from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletion
+from openai.types.chat.chat_completion import (
+    ChatCompletion, ChatCompletionTokenLogprob, ChoiceLogprobs)
 
 from eli5.base import Explanation, TargetExplanation, WeightedSpans, DocWeightedSpans
 from eli5.explain import explain_prediction
@@ -49,15 +51,15 @@ def explain_prediction_openai_logprobs(logprobs: ChoiceLogprobs, doc=None):

@explain_prediction.register(ChatCompletion)
def explain_prediction_openai_completion(
-        chat_completion: ChatCompletion, doc=None):
+        completion: ChatCompletion, doc=None):
    """ Creates an explanation of the ChatCompletion's logprobs
    highlighting them proportionally to the log probability.
    More likely tokens are highlighted in green,
    while unlikely tokens are highlighted in red.
    ``doc`` argument is ignored.
    """
    targets = []
-    for choice in chat_completion.choices:
+    for choice in completion.choices:
        if choice.logprobs is None:
            raise ValueError('Predictions must be obtained with logprobs enabled')
        target, = explain_prediction_openai_logprobs(choice.logprobs).targets
@@ -92,8 +94,54 @@ def explain_prediction_openai_client(
    else:
        messages = doc
    kwargs['logprobs'] = True
-    chat_completion = client.chat.completions.create(
+    completion = client.chat.completions.create(
        messages=messages,  # type: ignore
        model=model,
        **kwargs)
-    return explain_prediction_openai_completion(chat_completion)
+    for choice in completion.choices:
+        _recover_logprobs(choice.logprobs, model)
+        if choice.logprobs is None:
+            raise ValueError('logprobs not found, likely API does not support them')
+        if choice.logprobs.content is None:
+            raise ValueError(f'logprobs.content is empty: {choice.logprobs}')
+    return explain_prediction_openai_completion(completion)


def _recover_logprobs(logprobs: Optional[ChoiceLogprobs], model: str):
    """ Some servers don't populate ``logprobs.content``; try to recover it.
    """
    if logprobs is None:
        return
    if logprobs.content is not None:
        return
    if not (
            getattr(logprobs, 'token_logprobs', None) and
            getattr(logprobs, 'tokens', None)):
        return
    assert hasattr(logprobs, 'token_logprobs')  # for mypy
    assert hasattr(logprobs, 'tokens')  # for mypy
    try:
        import tokenizers
    except ImportError:
        warnings.warn('tokenizers library required to recover logprobs.content')
        return
    try:
        tokenizer = tokenizers.Tokenizer.from_pretrained(model)
    except Exception:
        warnings.warn(f'could not load tokenizer for {model} with tokenizers library')
        return
    assert len(logprobs.token_logprobs) == len(logprobs.tokens)
    # Recover token strings (including leading whitespace) by decoding the
    # token ids and re-encoding the text to get character offsets.
    text = tokenizer.decode(logprobs.tokens)
    encoded = tokenizer.encode(text, add_special_tokens=False)
    text_tokens = [text[start:end] for (start, end) in encoded.offsets]
    logprobs.content = []
    for logprob, token in zip(logprobs.token_logprobs, text_tokens):
        logprobs.content.append(
            ChatCompletionTokenLogprob(
                token=token,
                bytes=list(map(int, token.encode('utf8'))),
                logprob=logprob,
                top_logprobs=[],  # we could recover these too
            )
        )
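

As a side note, the decode/re-encode round-trip used above to recover token
strings can be sanity-checked in isolation. A minimal sketch (it assumes the
``tokenizers`` library and access to the Hugging Face Hub; the model name is
the one used elsewhere in this PR)::

    import tokenizers

    tok = tokenizers.Tokenizer.from_pretrained(
        'mlx-community/Mistral-7B-Instruct-v0.3-4bit')
    ids = tok.encode('Hello world world', add_special_tokens=False).ids
    text = tok.decode(ids)
    offsets = tok.encode(text, add_special_tokens=False).offsets
    # Token strings including leading spaces, e.g. ['Hello', ' world', ' world']
    print([text[start:end] for start, end in offsets])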
134 changes: 131 additions & 3 deletions notebooks/explain_llm_logprobs.ipynb

Large diff not rendered.

52 changes: 45 additions & 7 deletions tests/test_llm_explain_prediction.py
@@ -3,6 +3,7 @@
from unittest.mock import Mock

pytest.importorskip('openai')
pytest.importorskip('transformers')
from openai.types.chat.chat_completion import (
    ChoiceLogprobs,
    ChatCompletion,
@@ -11,6 +12,7 @@
    Choice,
)
from openai import Client
import transformers

import eli5
from eli5.base import Explanation
@@ -40,20 +42,28 @@ def example_logprobs():

@pytest.fixture
def example_completion(example_logprobs):
    return create_completion(
        model='gpt-4o-2024-08-06',
        logprobs=example_logprobs,
        message=ChatCompletionMessage(
            content=''.join(x.token for x in example_logprobs.content),
            role='assistant',
        ),
    )


def create_completion(model, logprobs, message):
    return ChatCompletion(
        id='chatcmpl-x',
        created=1743590849,
-        model='gpt-4o-2024-08-06',
+        model=model,
        object='chat.completion',
        choices=[
            Choice(
-                logprobs=example_logprobs,
+                logprobs=logprobs,
                finish_reason='stop',
                index=0,
-                message=ChatCompletionMessage(
-                    content=''.join(x.token for x in example_logprobs.content),
-                    role='assistant',
-                ),
+                message=message,
            )
        ],
    )
@@ -100,7 +110,35 @@ def __init__(self, chat_return_value):
def test_explain_prediction_openai_client(monkeypatch, example_completion):
    client = MockClient(example_completion)

-    explanation = eli5.explain_prediction(client, doc="Hello world", model="gpt-4o")
+    explanation = eli5.explain_prediction(client, doc="Hello world world", model="gpt-4o")
    _assert_explanation_structure_and_html(explanation)

    client.chat.completions.create.assert_called_once()


def test_explain_prediction_openai_client_mlx(monkeypatch):
    model = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
    tokenizer = transformers.AutoTokenizer.from_pretrained(model)

    text = 'Hello world world'
    tokens = tokenizer.encode(text, add_special_tokens=False)
    assert len(tokens) == 3
    logprobs = ChoiceLogprobs(
        token_logprobs=[
            math.log(0.9),
            math.log(0.2),
            math.log(0.4),
        ],
        tokens=tokens,
    )
    completion = create_completion(
        model=model,
        logprobs=logprobs,
        message=ChatCompletionMessage(content=text, role='assistant'),
    )
    client = MockClient(completion)

    explanation = eli5.explain_prediction(client, doc=text, model=model)
    _assert_explanation_structure_and_html(explanation)

    client.chat.completions.create.assert_called_once()
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ deps=
    pandas
    sklearn-crfsuite
    openai
    tokenizers
commands=
    pip install -e .
    py.test --doctest-modules \