Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions indextts/gpt/transformers_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_token_type_ids,
)
Expand Down Expand Up @@ -104,6 +103,42 @@
)


# Compatibility function for older transformers versions
def _crop_past_key_values(model, past_key_values, new_cache_size):
"""
Crop past_key_values to match new_cache_size for compatibility with newer transformers versions.
"""
if past_key_values is None:
return None

cropped_past_key_values = []
for layer_past in past_key_values:
if layer_past is None:
cropped_past_key_values.append(None)
continue

# Each layer_past should be a tuple of (key, value) tensors
if isinstance(layer_past, (tuple, list)) and len(layer_past) == 2:
key, value = layer_past
# Crop along the sequence length dimension (usually dim=2)
if key is not None and key.size(2) > new_cache_size:
cropped_key = key[:, :, :new_cache_size, :]
else:
cropped_key = key

if value is not None and value.size(2) > new_cache_size:
cropped_value = value[:, :, :new_cache_size, :]
else:
cropped_value = value

cropped_past_key_values.append((cropped_key, cropped_value))
else:
# Fallback for unexpected formats
cropped_past_key_values.append(layer_past)

return tuple(cropped_past_key_values)


if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Expand Down Expand Up @@ -1002,7 +1037,8 @@ def _get_logits_processor(
device=device,
)
)
if generation_config.forced_decoder_ids is not None:
# Check for forced_decoder_ids attribute compatibility with newer transformers versions
if hasattr(generation_config, 'forced_decoder_ids') and generation_config.forced_decoder_ids is not None:
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
raise ValueError(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
Expand Down
2 changes: 1 addition & 1 deletion indextts/gpt/transformers_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from indextts.gpt.transformers_generation_utils import GenerationMixin
from indextts.gpt.transformers_modeling_utils import PreTrainedModel
from transformers.modeling_utils import SequenceSummary
from transformers.models.gpt2.modeling_gpt2 import GPT2SequenceSummary as SequenceSummary

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import (
Expand Down
8 changes: 1 addition & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@ sentencepiece
tqdm
textstat

pynini==2.1.6; platform_system!="Windows"
WeTextProcessing>=1.0.3; platform_system!="Windows"

WeTextProcessing; platform_machine != "Darwin"
wetext; platform_system == "Darwin"

# importlib_resources
# pynini==2.1.6.post1
# WeTextProcessing>=1.0.4
# deepspeed # Use it to accelerate model inference
# deepspeed # Use it to accelerate model inference