Fix WordStoppingCriteria device mismatch #3
Co-authored-by: AmitMY <5757359+AmitMY@users.noreply.github.com>
AmitMY
left a comment
@copilot please change based on my notes
words_segmentation/pretokenizer.py
Outdated
  texts = [self.tokenizer.decode(ids) for ids in input_ids]
  is_done = [is_word_complete(text) for text in texts]
- return torch.BoolTensor(is_done)
+ return torch.tensor(is_done, dtype=torch.bool, device=device)
use device=input_ids.device directly
Done in commit 58d73a7. Removed the intermediate device variable and now using device=input_ids.device directly in the tensor constructor.
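For readers following the thread, here is a minimal, self-contained sketch of the pattern under discussion. It is not the repository's code: `is_word_complete` and `MockTokenizer` below are hypothetical stand-ins (a word counts as complete when the decoded text ends in whitespace, and each token id is treated as a Unicode code point), and the class omits whatever base class the real `WordStoppingCriteria` uses. The relevant part is the last line of `__call__`: the legacy `torch.BoolTensor(...)` constructor builds a CPU tensor, while `torch.tensor(..., device=input_ids.device)` places the result next to the input.

```python
import torch


def is_word_complete(text: str) -> bool:
    # Hypothetical stand-in for the project's helper: treat a word as
    # complete once the decoded text ends in whitespace.
    return len(text) > 0 and text[-1].isspace()


class MockTokenizer:
    # Hypothetical tokenizer for illustration: each id is a code point.
    def decode(self, ids) -> str:
        return "".join(chr(int(i)) for i in ids)


class WordStoppingCriteria:
    # Sketch only; the real class may derive from a library base class.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.Tensor, scores=None, **kwargs) -> torch.Tensor:
        texts = [self.tokenizer.decode(ids) for ids in input_ids]
        is_done = [is_word_complete(text) for text in texts]
        # Build the mask on the same device as input_ids so it can be
        # combined with other tensors there without a device mismatch.
        return torch.tensor(is_done, dtype=torch.bool, device=input_ids.device)


# Use a GPU when one is available; otherwise the check still runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_ids = torch.tensor(
    [[ord("h"), ord("i"), ord(" ")], [ord("h"), ord("e"), ord("y")]],
    device=device,
)
done = WordStoppingCriteria(MockTokenizer())(input_ids, None)
print(done.device == input_ids.device)
```

On a CPU-only machine the device check passes trivially, so a regression test for the mismatch itself only exercises the bug when a non-CPU device is present.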
tests/test_pretokenizer.py
Outdated
  assert not is_word_complete("こんにちは")


  def test_word_stopping_criteria_basic():
move these tests to a tests/test_word_stopping_criteria.py
Done in commit 58d73a7. Moved all WordStoppingCriteria tests and the MockTokenizer class to tests/test_word_stopping_criteria.py.
… file Co-authored-by: AmitMY <5757359+AmitMY@users.noreply.github.com>
- Identify the issue: WordStoppingCriteria.__call__ creates torch.BoolTensor without specifying a device, causing a device mismatch when input_ids is on a non-CPU device
- Add .gitignore entries for __pycache__ files
- Add tests for WordStoppingCriteria that verify current behavior
- Fix the __call__ method to respect the device of input_ids by passing it to torch.tensor()
- Use device=input_ids.device directly instead of an intermediate variable
- Move the tests to tests/test_word_stopping_criteria.py

Summary
Successfully addressed review feedback:
- Changed from device = input_ids.device followed by device=device to directly using device=input_ids.device inline
- Created tests/test_word_stopping_criteria.py with all 5 WordStoppingCriteria tests and the MockTokenizer class

All 51 tests pass, and linting passes with no issues.