Hardcoded GENERATION_REGEX logic (with return_assistant_tokens_mask) causes incorrect mask production for chat examples in _chat_preprocess function #2598

@DzmitryPihulski

Description

Describe the bug

The function _chat_preprocess is used to apply the chat template to HF chat examples.

The examples are in the format of messages. To correctly produce the mask and train only on assistant messages, the user/system/tool messages need to be masked (mask=0). The function logic is:

template_has_generation_kwd = GENERATION_REGEX.search(tokenizer.chat_template) is not None

tokenized_chat = tokenizer.apply_chat_template(
    chat,
    tools=tools,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=template_has_generation_kwd,
)

input_ids = tokenized_chat.get("input_ids")
if template_has_generation_kwd:
    mask = tokenized_chat["assistant_masks"]
else:
    mask = [1] * len(input_ids)

The issue here is that not all chat templates contain the {% generation %} keyword (for example, the Nemotron 3 Nano model's template doesn't), which means we fall back to the default mask of all 1's. Such a mask is not suitable for SFT, because it forces the model to be trained on every user/system/tool message as well.
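To illustrate the fallback, here is a minimal sketch. The regex below is a hypothetical stand-in for the repo's actual GENERATION_REGEX; the real pattern may differ:

```python
import re

# Illustrative stand-in for the repo's GENERATION_REGEX (actual pattern may differ).
GENERATION_REGEX = re.compile(r"\{%-?\s*generation\s*-?%\}")

# A template that wraps assistant content in the generation keyword.
template_with_kwd = "{% generation %}{{ message.content }}{% endgeneration %}"
# A template (like Nemotron 3 Nano's) with no generation keyword at all.
template_without = "{{ message.content }}"

# With the keyword, apply_chat_template can return an assistant token mask.
print(GENERATION_REGEX.search(template_with_kwd) is not None)  # True

# Without it, the code above silently falls back to mask = [1] * len(input_ids),
# so every user/system/tool token contributes to the SFT loss.
print(GENERATION_REGEX.search(template_without) is not None)  # False
```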

As far as I understand, this issue only affects HF chats, because for non-HF chats the _preprocess function in the same file uses the _mask_targets function to correctly mask messages.

Expected behavior

Consider removing the return_assistant_tokens_mask parameter, since it doesn't help with many modern chat templates, and applying custom masking logic instead.
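One possible custom approach is a role-based mask built from per-message token spans, similar in spirit to what _mask_targets does for non-HF chats. A minimal sketch, where build_role_mask and its inputs are illustrative and not the repo's actual API:

```python
def build_role_mask(message_token_lengths, roles):
    """Return a per-token mask: 1 for assistant tokens, 0 otherwise.

    message_token_lengths: number of tokens each message occupies in the
    concatenated sequence, in order.
    roles: role string ("system", "user", "assistant", "tool") per message.
    """
    mask = []
    for length, role in zip(message_token_lengths, roles):
        value = 1 if role == "assistant" else 0
        mask.extend([value] * length)
    return mask


# Example: system (3 tokens), user (5), assistant (4), tool (2), assistant (3).
mask = build_role_mask(
    [3, 5, 4, 2, 3],
    ["system", "user", "assistant", "tool", "assistant"],
)
# Only the two assistant spans are marked with 1 and contribute to the loss.
print(mask)
```

This requires knowing each message's token span, which in practice means tokenizing messages individually (or using offsets from the template) rather than relying on {% generation %}.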
