Skip to content

jax.utils.safe_zip -> list(zip(*args, strict=True))#18

Merged
ASEM000 merged 3 commits intoASEM000:mainfrom
Nin17:safe_zip
Nov 15, 2025
Merged

jax.utils.safe_zip -> list(zip(*args, strict=True))#18
ASEM000 merged 3 commits intoASEM000:mainfrom
Nin17:safe_zip

Conversation

@Nin17
Copy link
Contributor

@Nin17 Nin17 commented Nov 14, 2025

Fixes #17.

Tests passing with jax 0.8.0

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR replaces the deprecated jax.util.safe_zip with a custom implementation using Python's built-in zip(*args, strict=True), enabling compatibility with JAX 0.8.0.

  • Implements safe_zip locally in kernex/_src/utils.py
  • Updates minimum Python version from 3.8 to 3.10 (required for zip(strict=True))
  • Updates imports to use the new local implementation

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
kernex/_src/utils.py Adds local safe_zip implementation using zip(*args, strict=True)
kernex/interface/resolve_utils.py Updates import to use local safe_zip from kernex._src.utils instead of deprecated jax.util.safe_zip
setup.py Updates minimum Python version requirement to 3.10 to support zip(strict=True)

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

}


def safe_zip(*args):
Copy link

Copilot AI Nov 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The safe_zip function lacks documentation. Consider adding a docstring that explains:

  • The purpose of the function (ensuring all iterables have the same length)
  • Parameters and return type
  • Example usage

For example:

def safe_zip(*args):
    """Zip iterables together, ensuring they all have the same length.
    
    Args:
        *args: Variable number of iterables to zip together.
        
    Returns:
        List of tuples containing elements from the input iterables.
        
    Raises:
        ValueError: If the input iterables have different lengths.
    """
    return list(zip(*args, strict=True))
Suggested change
def safe_zip(*args):
def safe_zip(*args):
"""Zip iterables together, ensuring they all have the same length.
Args:
*args: Variable number of iterables to zip together.
Returns:
List of tuples containing elements from the input iterables.
Raises:
ValueError: If the input iterables have different lengths.
Example:
>>> safe_zip([1, 2], ['a', 'b'])
[(1, 'a'), (2, 'b')]
"""

Copilot uses AI. Check for mistakes.
@ASEM000 ASEM000 merged commit 9abc556 into ASEM000:main Nov 15, 2025
3 checks passed
@ASEM000
Copy link
Owner

ASEM000 commented Nov 15, 2025

Thanks @Nin17

@Nin17 Nin17 deleted the safe_zip branch November 15, 2025 22:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax.util.safe_zip deprecated and removed.

3 participants