Skip to content

Conversation

@maharajamihir
Copy link

No description provided.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR refactors checkpoint loading to use the new Orbax CheckpointManager with TrainState and introduces a Grain-based dataloader initialization helper.

  • Swap out PyTreeCheckpointer for Orbax CheckpointManager and shape-only dummy state restoration
  • Add checkpoint_step and lam_co_train flags to Args
  • Extract dataloader setup into _get_dataloader_iterator() using Grain’s iterator API

maharajamihir and others added 3 commits July 14, 2025 18:54
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants