From ec441b03c4ad60a4d7a703c9c3ef38dd2d11feb3 Mon Sep 17 00:00:00 2001
From: Konstantin Lopuhin
Date: Sat, 5 Apr 2025 23:11:40 +0100
Subject: [PATCH 1/2] support open source LLMs with mlx-lm

---
 .../_notebooks/explain_llm_logprobs.rst | 126 +++++++++++++++-
 eli5/llm/explain_prediction.py          |  52 ++++++-
 notebooks/explain_llm_logprobs.ipynb    | 134 +++++++++++++++++-
 tests/test_llm_explain_prediction.py    |  52 ++++++-
 tox.ini                                 |   1 +
 5 files changed, 346 insertions(+), 19 deletions(-)

diff --git a/docs/source/_notebooks/explain_llm_logprobs.rst b/docs/source/_notebooks/explain_llm_logprobs.rst
index 6efa805..812fbc5 100644
--- a/docs/source/_notebooks/explain_llm_logprobs.rst
+++ b/docs/source/_notebooks/explain_llm_logprobs.rst
@@ -11,6 +11,9 @@ about its predictions:
 
    LLM token probabilities visualized with eli5.explain_prediction
 
+1. OpenAI models
+----------------
+
 To follow this tutorial you need the ``openai`` library installed and
 working.
 
@@ -64,10 +67,10 @@ properties from a free-form product description:
 
    json
    {
      "materials": ["metal"],
-     "type": "table lamp",
+     "type": "task lighting",
      "color": "silky matte grey",
      "price": 150.00,
-     "summary": "Stay is a flexible and elegant table lamp designed by Maria Berntsen."
+     "summary": "Stay table lamp with adjustable arm and head for optimal task lighting."
    }
 
@@ -311,8 +314,8 @@ We can obtain the original prediction from the explanation object via
 ``explanation.targets[0].target.message.content`` to get the prediction
 text.
 
-Limitations
------------
+2. Limitations
+--------------
 
 Even though above the model confidence matched our expectations, it’s
 not always the case. For example, if we use “Chain of Thought”
@@ -531,6 +534,121 @@ temperatures:
 
 
+
+
+3. Open Source and other models
+-------------------------------
+
+If an API endpoint can provide ``logprobs`` in the right format, then it
+should work. However, few APIs or libraries provide it, even for open
+source models. One library which is known to work is ``mlx_lm`` (macOS
+only), e.g. if you start the server like this:
+
+::
+
+    mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
+
+Then you can explain predictions with a custom client:
+
+.. code:: ipython3
+
+    client_custom = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="dummy")
+    eli5.explain_prediction(
+        client_custom,
+        prompt + ' Price should never be zero.',
+        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+    )
+
+.. raw:: html
+
+    [rendered HTML output omitted: per-token logprob highlighting of the following JSON]
+
+    {
+      "materials": ["silky matte grey metal"],
+      "type": "Not specified in the description",
+      "color": "Not specified in the description",
+      "price": 99.99,
+      "summary": "Stay is a flexible and beautiful Danish-designed table lamp with a discreet switch and adjustable arm and head, ideal for office task lighting."
+    }
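+
+A quick way to check whether a server returns ``logprobs`` in the
+expected format is to request a completion directly and inspect the
+result (a sketch; it re-uses ``client_custom`` defined above and
+assumes the server accepts the standard ``logprobs`` flag):
+
+.. code:: ipython3
+
+    completion = client_custom.chat.completions.create(
+        messages=[{"role": "user", "content": "Hello"}],
+        model="mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+        logprobs=True,
+    )
+    # populated if the server supports logprobs
+    completion.choices[0].logprobs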
diff --git a/eli5/llm/explain_prediction.py b/eli5/llm/explain_prediction.py
index 8d078be..fdc3666 100644
--- a/eli5/llm/explain_prediction.py
+++ b/eli5/llm/explain_prediction.py
@@ -1,8 +1,10 @@
 import math
+import warnings
 from typing import Union
 
 import openai
-from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletion
+from openai.types.chat.chat_completion import (
+    ChatCompletion, ChatCompletionTokenLogprob, ChoiceLogprobs)
 
 from eli5.base import Explanation, TargetExplanation, WeightedSpans, DocWeightedSpans
 from eli5.explain import explain_prediction
@@ -49,7 +51,7 @@ def explain_prediction_openai_logprobs(logprobs: ChoiceLogprobs, doc=None):
 
 @explain_prediction.register(ChatCompletion)
 def explain_prediction_openai_completion(
-        chat_completion: ChatCompletion, doc=None):
+        completion: ChatCompletion, doc=None):
     """ Creates an explanation of the ChatCompletion's logprobs highlighting them
     proportionally to the log probability. More likely tokens are highlighted in green,
@@ -57,7 +59,7 @@
     ``doc`` argument is ignored.
     """
     targets = []
-    for choice in chat_completion.choices:
+    for choice in completion.choices:
         if choice.logprobs is None:
             raise ValueError('Predictions must be obtained with logprobs enabled')
         target, = explain_prediction_openai_logprobs(choice.logprobs).targets
@@ -92,8 +94,48 @@ def explain_prediction_openai_client(
     else:
         messages = doc
     kwargs['logprobs'] = True
-    chat_completion = client.chat.completions.create(
+    completion = client.chat.completions.create(
         messages=messages,  # type: ignore
         model=model,
         **kwargs)
-    return explain_prediction_openai_completion(chat_completion)
+    for choice in completion.choices:
+        logprobs = choice.logprobs
+        if logprobs is None:
+            raise ValueError('logprobs not found, likely API does not support them')
+        if logprobs.content is None:
+            _recover_logprobs_content(logprobs, model)
+        if logprobs.content is None:
+            raise ValueError(f'logprobs.content is empty: {logprobs}')
+    return explain_prediction_openai_completion(completion)
+
+
+def _recover_logprobs_content(logprobs: ChoiceLogprobs, model: str):
+    """ Some servers don't populate logprobs.content, try to recover it.
+    """
+    if not (logprobs.token_logprobs and logprobs.tokens):
+        return
+    try:
+        import tokenizers
+    except ImportError:
+        warnings.warn('tokenizers library required to recover logprobs.content')
+        return
+    try:
+        tokenizer = tokenizers.Tokenizer.from_pretrained(model)
+    except Exception:
+        warnings.warn(f'could not load tokenizer for {model} with tokenizers library')
+        return
+    assert len(logprobs.token_logprobs) == len(logprobs.tokens)
+    # get tokens as strings with spaces, is there any better way?
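+    # (the approach below: decode the token ids back to text, then
+    # re-encode to recover character offsets, and slice the text by
+    # those offsets so each token string keeps its leading whitespace;
+    # this assumes the tokenizer round-trip is lossless)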
+    text = tokenizer.decode(logprobs.tokens)
+    encoded = tokenizer.encode(text, add_special_tokens=False)
+    text_tokens = [text[start:end] for (start, end) in encoded.offsets]
+    logprobs.content = []
+    for logprob, token in zip(logprobs.token_logprobs, text_tokens):
+        logprobs.content.append(
+            ChatCompletionTokenLogprob(
+                token=token,
+                bytes=list(map(int, token.encode('utf8'))),
+                logprob=logprob,
+                top_logprobs=[],  # we could recover that too
+            )
+        )
diff --git a/notebooks/explain_llm_logprobs.ipynb b/notebooks/explain_llm_logprobs.ipynb
index d697ead..5deb1ff 100644
--- a/notebooks/explain_llm_logprobs.ipynb
+++ b/notebooks/explain_llm_logprobs.ipynb
@@ -13,6 +13,8 @@
    "\n",
    "![LLM token probabilities visualized with eli5.explain_prediction](../docs/source/static/llm-explain-logprobs.png)\n",
    "\n",
+   "## 1. OpenAI models\n",
+   "\n",
    "To follow this tutorial you need the ``openai`` library installed and working."
   ]
  },
@@ -50,10 +52,10 @@
    "json\n",
    "{\n",
    "  \"materials\": [\"metal\"],\n",
-   "  \"type\": \"table lamp\",\n",
+   "  \"type\": \"task lighting\",\n",
    "  \"color\": \"silky matte grey\",\n",
    "  \"price\": 150.00,\n",
-   "  \"summary\": \"Stay is a flexible and elegant table lamp designed by Maria Berntsen.\"\n",
+   "  \"summary\": \"Stay table lamp with adjustable arm and head for optimal task lighting.\"\n",
    "}\n",
    "\n"
   ]
@@ -364,7 +366,7 @@
   "id": "a136a3c3-f840-403a-9a3a-9b7cd07065f4",
   "metadata": {},
   "source": [
-   "## Limitations\n",
+   "## 2. Limitations\n",
    "\n",
    "Even though above the model confidence matched our expectations, it's not always the case.\n",
    "For example, if we use \"Chain of Thought\" (https://arxiv.org/abs/2201.11903) reasoning,\n",
@@ -623,6 +625,132 @@
   "source": [
    "eli5.explain_prediction(client, prompt_cot, model='gpt-4o')"
   ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "102dbd61-7b37-4f52-b54e-3372fddc3ae7",
+  "metadata": {},
+  "source": [
+   "## 3. Open Source and other models\n",
+   "\n",
+   "If an API endpoint can provide ``logprobs`` in the right format, then it should work. However, few APIs or libraries provide it,\n",
+   "even for open source models. One library which is known to work is `mlx_lm` (macOS only), e.g. if you start the server like this:\n",
+   "\n",
+   "    mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit\n",
+   "\n",
+   "Then you can explain predictions with a custom client:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 5,
+  "id": "fe88ea0c-27d1-428d-b807-4f461fd5471b",
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/html": [
+      "[rendered HTML table omitted: per-token logprob highlighting of this JSON]\n",
+      "{\n",
+      "  \"materials\": [\"silky matte grey metal\"],\n",
+      "  \"type\": \"Not specified in the description\",\n",
+      "  \"color\": \"Not specified in the description\",\n",
+      "  \"price\": 99.99,\n",
+      "  \"summary\": \"Stay is a flexible and beautiful Danish-designed table lamp with a discreet switch and adjustable arm and head, ideal for office task lighting.\"\n",
+      "}"
+     ],
+     "text/plain": [
+      "Explanation(estimator='llm_logprobs', ...)  [full token-level Explanation repr omitted]"
+     ]
+    },
+    "execution_count": 5,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "client_custom = openai.OpenAI(base_url=\"http://localhost:8080/v1\", api_key=\"dummy\")\n",
+   "eli5.explain_prediction(\n",
+   "    client_custom,\n",
+   "    prompt + ' Price should never be zero.',\n",
+   "    model=\"mlx-community/Mistral-7B-Instruct-v0.3-4bit\",\n",
+   ")"
+  ]
+ }
 ],
 "metadata": {
diff --git a/tests/test_llm_explain_prediction.py b/tests/test_llm_explain_prediction.py
index a8faaee..a7c7d12 100644
--- a/tests/test_llm_explain_prediction.py
+++ b/tests/test_llm_explain_prediction.py
@@ -3,6 +3,7 @@
 from unittest.mock import Mock
 
 pytest.importorskip('openai')
+pytest.importorskip('transformers')
 
 from openai.types.chat.chat_completion import (
     ChoiceLogprobs,
     ChatCompletion,
@@ -11,6 +12,7 @@
     Choice,
 )
 from openai import Client
+import transformers
 
 import eli5
 from eli5.base import Explanation
@@ -40,20 +42,28 @@ def example_logprobs():
 
 @pytest.fixture
 def example_completion(example_logprobs):
+    return create_completion(
+        model='gpt-4o-2024-08-06',
+        logprobs=example_logprobs,
+        message=ChatCompletionMessage(
+            content=''.join(x.token for x in example_logprobs.content),
+            role='assistant',
+        ),
+    )
+
+
+def create_completion(model, logprobs, message):
     return ChatCompletion(
         id='chatcmpl-x',
         created=1743590849,
-        model='gpt-4o-2024-08-06',
+        model=model,
         object='chat.completion',
         choices=[
             Choice(
-                logprobs=example_logprobs,
+                logprobs=logprobs,
                 finish_reason='stop',
                 index=0,
-                message=ChatCompletionMessage(
-                    content=''.join(x.token for x in example_logprobs.content),
-                    role='assistant',
-                ),
+                message=message,
             )
         ],
     )
@@ -100,7 +110,35 @@ def __init__(self, chat_return_value):
 
 def test_explain_prediction_openai_client(monkeypatch, example_completion):
     client = MockClient(example_completion)
-    explanation = eli5.explain_prediction(client, doc="Hello world", model="gpt-4o")
+    explanation = eli5.explain_prediction(client, doc="Hello world world", model="gpt-4o")
+    _assert_explanation_structure_and_html(explanation)
+
+    client.chat.completions.create.assert_called_once()
+
+
+def test_explain_prediction_openai_client_mlx(monkeypatch):
+    model = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model)
+
+    text = 'Hello world world'
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+    assert len(tokens) == 3
+    logprobs = ChoiceLogprobs(
+        token_logprobs=[
+            math.log(0.9),
+            math.log(0.2),
+            math.log(0.4),
+        ],
+        tokens=tokens,
+    )
+    completion = create_completion(
+        model=model,
+        logprobs=logprobs,
+        message=ChatCompletionMessage(content=text, role='assistant'),
+    )
+    client = MockClient(completion)
+
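+    # The mocked response carries logprobs in the mlx_lm server style
+    # (token ids in `tokens` plus `token_logprobs`, with `content` unset),
+    # so explain_prediction has to rebuild logprobs.content through the
+    # tokenizer round-trip in _recover_logprobs_content.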
+    explanation = eli5.explain_prediction(client, doc=text, model=model)
     _assert_explanation_structure_and_html(explanation)
 
     client.chat.completions.create.assert_called_once()
diff --git a/tox.ini b/tox.ini
index d071724..acae046 100644
--- a/tox.ini
+++ b/tox.ini
@@ -30,6 +30,7 @@ deps=
     pandas
     sklearn-crfsuite
     openai
+    tokenizers
 commands=
     pip install -e .
     py.test --doctest-modules \

From 8f883c11f7072c5a7c324bdeb79714f782dd881c Mon Sep 17 00:00:00 2001
From: Konstantin Lopuhin
Date: Sun, 6 Apr 2025 09:47:27 +0100
Subject: [PATCH 2/2] fix typecheck

---
 eli5/llm/explain_prediction.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/eli5/llm/explain_prediction.py b/eli5/llm/explain_prediction.py
index fdc3666..443d1e6 100644
--- a/eli5/llm/explain_prediction.py
+++ b/eli5/llm/explain_prediction.py
@@ -1,6 +1,6 @@
 import math
 import warnings
-from typing import Union
+from typing import Optional, Union
 
 import openai
 from openai.types.chat.chat_completion import (
@@ -99,21 +99,27 @@ def explain_prediction_openai_client(
         model=model,
         **kwargs)
     for choice in completion.choices:
-        logprobs = choice.logprobs
-        if logprobs is None:
+        _recover_logprobs(choice.logprobs, model)
+        if choice.logprobs is None:
             raise ValueError('logprobs not found, likely API does not support them')
-        if logprobs.content is None:
-            _recover_logprobs_content(logprobs, model)
-        if logprobs.content is None:
-            raise ValueError(f'logprobs.content is empty: {logprobs}')
+        if choice.logprobs.content is None:
+            raise ValueError(f'logprobs.content is empty: {choice.logprobs}')
     return explain_prediction_openai_completion(completion)
 
 
-def _recover_logprobs_content(logprobs: ChoiceLogprobs, model: str):
+def _recover_logprobs(logprobs: Optional[ChoiceLogprobs], model: str):
     """ Some servers don't populate logprobs.content, try to recover it.
     """
-    if not (logprobs.token_logprobs and logprobs.tokens):
+    if logprobs is None:
+        return
+    if logprobs.content is not None:
+        return
+    if not (
+            getattr(logprobs, 'token_logprobs', None) and
+            getattr(logprobs, 'tokens', None)):
         return
+    assert hasattr(logprobs, 'token_logprobs')  # for mypy
+    assert hasattr(logprobs, 'tokens')  # for mypy
     try:
         import tokenizers
     except ImportError: