Conversation

@coder0143
Contributor

@coder0143 coder0143 commented Dec 11, 2025

Resolves #98

Reference

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@coder0143
Contributor Author

Can you please review my PR, @chapman20j?

@chapman20j
Collaborator

Hi @coder0143. Thanks for the nice PR! I left a few comments. Having explicit configs here can help make it more clear what hyperparameters are used in constructing the model and could simplify some parts of the code. Also, including more testing ensures model correctness. Looking forward to the final version!

@coder0143
Contributor Author

Thank you so much for reviewing and replying, @chapman20j. I have made the following changes:

  • Written custom configs for loading the specific models (modeling.py)
  • Removed transformers dependency from modeling and params
  • Removed cosine similarity and added more tests to the test_outputs file
  • Updated the colab notebook

@jenriver
Member

Hi, could you ensure that the tests above are passing? i.e.

  1. Please ensure you have run pre-commit run --all-files as in contribution guidelines.
  2. The CI is currently failing with a GatedRepoError because the dinov3 checkpoint is restricted. Could you update the test to use randomly initialized weights instead? (Note: Please ensure the JAX and PyTorch models are initialized with the same random weights so the parity assertions still pass. Creating a random PyTorch model and converting it to JAX within setUp usually works best.)
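The setUp pattern described in point 2 can be sketched as follows. This is a minimal illustration in which a toy torch.nn.Linear stands in for the actual ViT; the real parameter mapping depends on the PR's modeling code:

```python
import numpy as np
import torch
import jax.numpy as jnp

# Hypothetical stand-in for the real ViT: any randomly initialized
# PyTorch module works, as long as its weights are then copied to JAX.
torch.manual_seed(0)
torch_model = torch.nn.Linear(4, 3)

# Convert the random PyTorch weights to JAX arrays so both frameworks
# start from identical parameters.
jax_params = {
    name: jnp.asarray(p.detach().numpy())
    for name, p in torch_model.named_parameters()
}

# Run the same input through both frameworks.
x = np.random.default_rng(0).standard_normal((2, 4)).astype(np.float32)
torch_out = torch_model(torch.from_numpy(x)).detach().numpy()
jax_out = np.asarray(x @ jax_params["weight"].T + jax_params["bias"])

# Parity assertion, analogous to the test's torch.testing.assert_close.
np.testing.assert_allclose(jax_out, torch_out, rtol=1e-5, atol=1e-5)
```

Because the weights are identical on both sides, the parity assertions remain meaningful without downloading the gated checkpoint.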

@coder0143
Contributor Author

coder0143 commented Dec 17, 2025

Thank you for reviewing, @jenriver. I have made the following changes:

  • Checked and updated the files based on ruff formatting.
  • Removed the run_model.py file; the colab notebook can be used instead.
  • Updated test_outputs.py to use a randomly initialized vit_b16 model and run the tests against it.
  • The pre-commit command runs fine.

@jenriver
Member

Hi, we're still seeing pre-commit failures as above -- could you ensure you have run pre-commit hooks?

i.e.
pre-commit run --all-files as in contribution guidelines.

Comment on lines 18 to 19
raw_path = "~/.cache/huggingface/dinov3_vitb16"
self.save_dir = os.path.expanduser(raw_path)
Member


Could you not use local directory paths?

i.e. Something like

self.save_dir = snapshot_download(...)
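A hedged sketch of that suggestion, assuming the huggingface_hub client; the specific repo_id is whatever checkpoint the test targets and is deliberately left as a parameter here:

```python
from huggingface_hub import snapshot_download


def checkpoint_dir(repo_id: str) -> str:
    """Resolve a checkpoint directory via the managed HF cache.

    snapshot_download fetches the repo on first use and returns the
    local cache path on subsequent calls, avoiding hard-coded paths
    like ~/.cache/huggingface/dinov3_vitb16.
    """
    return snapshot_download(repo_id=repo_id)
```

This keeps the test portable across machines, since huggingface_hub decides where the cache lives.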

@coder0143
Contributor Author

I have made the necessary changes and everything should pass now. Thank you for reviewing and guiding my PR, @chapman20j and @jenriver. By the way, I have sent you a connect request on LinkedIn!

np_y = np.asarray(jax.device_get(jy))
ty_bonsai = torch.tensor(np_y, dtype=torch.float32)

torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=3e-1)
Member


This is quite a high tolerance. If RoPE casting and LayerNorms are correctly aligned, we should be seeing a value much tighter than this.

@coder0143
Contributor Author

Yeah, things are actually working fine. I updated the atol values and tested many times: for the first layer, setting atol to 2e-3 is enough, and even in the very worst case the max difference is about 0.0024. I have also updated and tested the other output functions. PyTorch casts to bfloat16 for the RoPE calculation and then casts back to float32, and I have done the same in JAX; this is mostly where the error is introduced, along with some other floating-point operations.
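The float32 → bfloat16 → float32 round-trip mentioned above can be measured in isolation. This sketch (independent of the PR's actual RoPE code) shows the error magnitude that a single such cast introduces:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(42)
x = rng.standard_normal(10_000).astype(np.float32)

# Cast to bfloat16 and back, mirroring the RoPE computation path.
roundtrip = np.asarray(jnp.asarray(x).astype(jnp.bfloat16).astype(jnp.float32))
err = np.abs(roundtrip - x)

# bfloat16 carries 8 bits of precision, so a single round-to-nearest
# cast has relative error at most 2**-8 (about 0.4%).
assert err.max() <= np.abs(x).max() * 2**-8
```

For unit-scale activations this already yields absolute differences on the order of 1e-3, consistent with the ~2e-3 atol observed for the first layer.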

@chapman20j chapman20j merged commit 9ccb175 into jax-ml:main Dec 23, 2025
3 checks passed
Successfully merging this pull request may close these issues.

Request to add Dinov3 ViT models
