Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import re
import string
import warnings
from collections import defaultdict
from typing import List, Union

Expand Down Expand Up @@ -133,6 +134,15 @@ def load_model(
if any("ForTokenClassification" in ar for ar in config.architectures)
else AutoModelForCausalLM
)
if "cuda" in device_map and not torch.cuda.is_available():
warnings.warn(
f"device_map='{device_map}' was requested but CUDA is not available "
f"(torch.cuda.is_available() is False). Falling back to 'cpu'. "
f"To silence this warning, pass device_map='cpu' explicitly.",
RuntimeWarning,
stacklevel=2,
)
device_map = "cpu"
self.device = (
device_map
if any(key in device_map for key in ["cuda", "cpu", "mps"])
Expand Down
173 changes: 173 additions & 0 deletions tests/test_issue_216.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""Regression test for https://github.com/microsoft/LLMLingua/issues/216

PromptCompressor defaulted to device_map="cuda" even when torch was built
without CUDA support, producing "AssertionError: Torch not compiled with CUDA
enabled" on Windows / CPU-only machines. The fix transparently falls back to
"cpu" (with a RuntimeWarning) when "cuda" is requested but unavailable.

These tests exercise the fallback branch directly using monkeypatching so they
run in a fraction of a second and do not depend on a real CUDA runtime.
"""

import unittest
import warnings
from unittest import mock

import torch

import llmlingua.prompt_compressor as prompt_compressor_module
from llmlingua import PromptCompressor


class _FakeConfig:
    """Minimal stand-in for a transformers ``PretrainedConfig``.

    The ``ForTokenClassification`` suffix in ``architectures`` steers
    ``load_model`` into its token-classification branch; the remaining
    attributes satisfy the lookups ``load_model`` performs on the config.
    """

    architectures = ["BertForTokenClassification"]
    pad_token_id = 0
    max_position_embeddings = 512


class _FakeTokenizer:
    """Minimal stand-in for a transformers tokenizer.

    Only the attributes and the single method that ``load_model`` touches
    are provided; everything else is deliberately absent so unexpected use
    fails loudly.
    """

    pad_token_id = 0
    eos_token_id = 0
    padding_side = "right"
    special_tokens_map = {"pad_token": "[PAD]"}

    def add_special_tokens(self, mapping):
        # Mirror the real API's return value: number of tokens added (none).
        return 0


class _FakeModel:
    """Minimal stand-in for a transformers model.

    ``to`` and ``eval`` return ``self`` so the common
    ``model.to(device).eval()`` chaining in ``load_model`` keeps working.
    """

    def resize_token_embeddings(self, new_len):
        # No-op: the fake has no embedding table to resize.
        return None

    def to(self, device):
        return self

    def eval(self):
        return self


class Issue216CudaFallbackTest(unittest.TestCase):
    """Verify that PromptCompressor falls back to CPU when CUDA is unavailable.

    Regression tests for issue #216: requesting ``device_map="cuda"`` on a
    torch build without CUDA support must transparently fall back to ``"cpu"``
    and emit exactly one RuntimeWarning — and must NOT warn when ``"cpu"`` is
    requested explicitly or when CUDA really is available.
    """

    def _patched_load_model(
        self,
        cuda_available: bool,
        requested_device_map: str,
    ):
        """Run ``PromptCompressor.load_model`` with HF auto-classes faked out.

        ``torch.cuda.is_available`` is forced to *cuda_available* and every
        ``Auto*`` factory in the module under test is replaced by a fake, so
        no model weights are downloaded and no CUDA runtime is needed.

        Returns:
            (compressor, captured): the compressor instance and a dict whose
            ``"device_map"`` entry records the keyword actually forwarded to
            the model factory.
        """
        captured = {}

        def build_fake(cls_name):
            # Stand-in for AutoConfig / AutoTokenizer / AutoModel* classes.
            # The model factory records the device_map it receives so the
            # tests can assert on the fallback behavior.
            def _factory(*args, **kwargs):
                if cls_name == "Config":
                    return _FakeConfig()
                if cls_name == "Tokenizer":
                    return _FakeTokenizer()
                captured["device_map"] = kwargs.get("device_map")
                return _FakeModel()

            return mock.Mock(from_pretrained=_factory)

        with mock.patch.object(
            torch.cuda, "is_available", return_value=cuda_available
        ), mock.patch.object(
            prompt_compressor_module, "AutoConfig", build_fake("Config")
        ), mock.patch.object(
            prompt_compressor_module, "AutoTokenizer", build_fake("Tokenizer")
        ), mock.patch.object(
            prompt_compressor_module,
            "AutoModelForTokenClassification",
            build_fake("Model"),
        ), mock.patch.object(
            prompt_compressor_module, "AutoModelForCausalLM", build_fake("Model")
        ):
            # __new__ bypasses __init__ so load_model is exercised in
            # isolation, without the constructor's own model loading.
            compressor = PromptCompressor.__new__(PromptCompressor)
            compressor.load_model(
                "microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
                device_map=requested_device_map,
            )
        return compressor, captured

    def test_cuda_requested_but_unavailable_falls_back_to_cpu(self):
        """device_map='cuda' without CUDA -> 'cpu' plus one RuntimeWarning."""
        with warnings.catch_warnings(record=True) as captured_warnings:
            warnings.simplefilter("always")
            compressor, passed_kwargs = self._patched_load_model(
                cuda_available=False, requested_device_map="cuda"
            )

        self.assertEqual(compressor.device, "cpu")
        self.assertEqual(passed_kwargs["device_map"], "cpu")

        runtime_warnings = [
            w
            for w in captured_warnings
            if issubclass(w.category, RuntimeWarning)
            and "CUDA is not available" in str(w.message)
        ]
        self.assertEqual(
            len(runtime_warnings),
            1,
            f"expected exactly one CUDA fallback RuntimeWarning, got "
            f"{len(runtime_warnings)}: {[str(w.message) for w in captured_warnings]}",
        )

    def test_cpu_requested_explicitly_is_untouched(self):
        """An explicit device_map='cpu' must pass through silently."""
        with warnings.catch_warnings(record=True) as captured_warnings:
            warnings.simplefilter("always")
            compressor, passed_kwargs = self._patched_load_model(
                cuda_available=False, requested_device_map="cpu"
            )

        self.assertEqual(compressor.device, "cpu")
        self.assertEqual(passed_kwargs["device_map"], "cpu")

        runtime_warnings = [
            w
            for w in captured_warnings
            if issubclass(w.category, RuntimeWarning)
            and "CUDA" in str(w.message)
        ]
        self.assertEqual(
            len(runtime_warnings),
            0,
            "explicit device_map='cpu' should not trigger any CUDA warning",
        )

    def test_cuda_requested_and_available_stays_on_cuda(self):
        """When CUDA is available, 'cuda' is kept and no warning is raised."""
        with warnings.catch_warnings(record=True) as captured_warnings:
            warnings.simplefilter("always")
            compressor, passed_kwargs = self._patched_load_model(
                cuda_available=True, requested_device_map="cuda"
            )

        self.assertEqual(compressor.device, "cuda")
        self.assertEqual(passed_kwargs["device_map"], "cuda")

        runtime_warnings = [
            w
            for w in captured_warnings
            if issubclass(w.category, RuntimeWarning)
            and "CUDA is not available" in str(w.message)
        ]
        self.assertEqual(
            len(runtime_warnings),
            0,
            "no fallback warning should be raised when CUDA is actually available",
        )


# Allow running the file directly: `python tests/test_issue_216.py`.
if __name__ == "__main__":
    unittest.main()