diff --git a/docs/source/data.rst b/docs/source/data.rst index 0297d14..2a84007 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -9,6 +9,9 @@ Data .. autoclass:: deepdiagnostics.data.H5Data :members: +.. autoclass:: deepdiagnostics.data.H5HierarchyData + :members: + .. autoclass:: deepdiagnostics.data.PickleData :members: diff --git a/docs/source/models.rst b/docs/source/models.rst index c4338da..7b3d4fe 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -8,3 +8,6 @@ Models .. autoclass:: deepdiagnostics.models.SBIModel :members: + +.. autoclass:: deepdiagnostics.models.HierarchyModel + :members: diff --git a/docs/source/plots.rst b/docs/source/plots.rst index 1069c2f..26a8de9 100644 --- a/docs/source/plots.rst +++ b/docs/source/plots.rst @@ -8,9 +8,15 @@ Plots .. autoclass:: deepdiagnostics.plots.CDFRanks :members: plot +.. autoclass:: deepdiagnostics.plots.HierarchyCDFRanks + :members: plot + .. autoclass:: deepdiagnostics.plots.Ranks :members: plot +.. autoclass:: deepdiagnostics.plots.HierarchyRanks + :members: plot + .. autoclass:: deepdiagnostics.plots.CoverageFraction :members: plot @@ -30,4 +36,7 @@ Plots .. autoclass:: deepdiagnostics.plots.Parity :members: plot +.. autoclass:: deepdiagnostics.plots.HierarchyParity + :members: plot + .. bibliography:: \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 1ac4276..24daa36 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -12,6 +12,142 @@ files = [ {file = "absl_py-2.3.0.tar.gz", hash = "sha256:d96fda5c884f1b22178852f30ffa85766d50b99e00775ea626c23304f582fc4f"}, ] +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"}, + {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"}, +] + +[[package]] +name = "aiohttp" +version = "3.12.15" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiohttp-3.12.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b6fc902bff74d9b1879ad55f5404153e2b33a82e72a95c89cec5eb6cc9e92fbc"}, + {file = "aiohttp-3.12.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:098e92835b8119b54c693f2f88a1dec690e20798ca5f5fe5f0520245253ee0af"}, + {file = "aiohttp-3.12.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:40b3fee496a47c3b4a39a731954c06f0bd9bd3e8258c059a4beb76ac23f8e421"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ce13fcfb0bb2f259fb42106cdc63fa5515fb85b7e87177267d89a771a660b79"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3beb14f053222b391bf9cf92ae82e0171067cc9c8f52453a0f1ec7c37df12a77"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c39e87afe48aa3e814cac5f535bc6199180a53e38d3f51c5e2530f5aa4ec58c"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d5f1b4ce5bc528a6ee38dbf5f39bbf11dd127048726323b72b8e85769319ffc4"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1004e67962efabbaf3f03b11b4c43b834081c9e3f9b32b16a7d97d4708a9abe6"}, + {file = "aiohttp-3.12.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8faa08fcc2e411f7ab91d1541d9d597d3a90e9004180edb2072238c085eac8c2"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fe086edf38b2222328cdf89af0dde2439ee173b8ad7cb659b4e4c6f385b2be3d"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:79b26fe467219add81d5e47b4a4ba0f2394e8b7c7c3198ed36609f9ba161aecb"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b761bac1192ef24e16706d761aefcb581438b34b13a2f069a6d343ec8fb693a5"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e153e8adacfe2af562861b72f8bc47f8a5c08e010ac94eebbe33dc21d677cd5b"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:fc49c4de44977aa8601a00edbf157e9a421f227aa7eb477d9e3df48343311065"}, + {file = "aiohttp-3.12.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2776c7ec89c54a47029940177e75c8c07c29c66f73464784971d6a81904ce9d1"}, + {file = "aiohttp-3.12.15-cp310-cp310-win32.whl", hash = "sha256:2c7d81a277fa78b2203ab626ced1487420e8c11a8e373707ab72d189fcdad20a"}, + {file = "aiohttp-3.12.15-cp310-cp310-win_amd64.whl", hash = "sha256:83603f881e11f0f710f8e2327817c82e79431ec976448839f3cd05d7afe8f830"}, + {file = "aiohttp-3.12.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d3ce17ce0220383a0f9ea07175eeaa6aa13ae5a41f30bc61d84df17f0e9b1117"}, + {file = "aiohttp-3.12.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:010cc9bbd06db80fe234d9003f67e97a10fe003bfbedb40da7d71c1008eda0fe"}, + {file = "aiohttp-3.12.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f9d7c55b41ed687b9d7165b17672340187f87a773c98236c987f08c858145a9"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc4fbc61bb3548d3b482f9ac7ddd0f18c67e4225aaa4e8552b9f1ac7e6bda9e5"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fbc8a7c410bb3ad5d595bb7118147dfbb6449d862cc1125cf8867cb337e8728"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74dad41b3458dbb0511e760fb355bb0b6689e0630de8a22b1b62a98777136e16"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b6f0af863cf17e6222b1735a756d664159e58855da99cfe965134a3ff63b0b0"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5b7fe4972d48a4da367043b8e023fb70a04d1490aa7d68800e465d1b97e493b"}, + {file = "aiohttp-3.12.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6443cca89553b7a5485331bc9bedb2342b08d073fa10b8c7d1c60579c4a7b9bd"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c5f40ec615e5264f44b4282ee27628cea221fcad52f27405b80abb346d9f3f8"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:2abbb216a1d3a2fe86dbd2edce20cdc5e9ad0be6378455b05ec7f77361b3ab50"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_i686.whl", hash = 
"sha256:db71ce547012a5420a39c1b744d485cfb823564d01d5d20805977f5ea1345676"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ced339d7c9b5030abad5854aa5413a77565e5b6e6248ff927d3e174baf3badf7"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:7c7dd29c7b5bda137464dc9bfc738d7ceea46ff70309859ffde8c022e9b08ba7"}, + {file = "aiohttp-3.12.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:421da6fd326460517873274875c6c5a18ff225b40da2616083c5a34a7570b685"}, + {file = "aiohttp-3.12.15-cp311-cp311-win32.whl", hash = "sha256:4420cf9d179ec8dfe4be10e7d0fe47d6d606485512ea2265b0d8c5113372771b"}, + {file = "aiohttp-3.12.15-cp311-cp311-win_amd64.whl", hash = "sha256:edd533a07da85baa4b423ee8839e3e91681c7bfa19b04260a469ee94b778bf6d"}, + {file = "aiohttp-3.12.15-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:802d3868f5776e28f7bf69d349c26fc0efadb81676d0afa88ed00d98a26340b7"}, + {file = "aiohttp-3.12.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f2800614cd560287be05e33a679638e586a2d7401f4ddf99e304d98878c29444"}, + {file = "aiohttp-3.12.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8466151554b593909d30a0a125d638b4e5f3836e5aecde85b66b80ded1cb5b0d"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e5a495cb1be69dae4b08f35a6c4579c539e9b5706f606632102c0f855bcba7c"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6404dfc8cdde35c69aaa489bb3542fb86ef215fc70277c892be8af540e5e21c0"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ead1c00f8521a5c9070fcb88f02967b1d8a0544e6d85c253f6968b785e1a2ab"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6990ef617f14450bc6b34941dba4f12d5613cbf4e33805932f853fbd1cf18bfb"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd736ed420f4db2b8148b52b46b88ed038d0354255f9a73196b7bbce3ea97545"}, + {file = "aiohttp-3.12.15-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c5092ce14361a73086b90c6efb3948ffa5be2f5b6fbcf52e8d8c8b8848bb97c"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aaa2234bb60c4dbf82893e934d8ee8dea30446f0647e024074237a56a08c01bd"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:6d86a2fbdd14192e2f234a92d3b494dd4457e683ba07e5905a0b3ee25389ac9f"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a041e7e2612041a6ddf1c6a33b883be6a421247c7afd47e885969ee4cc58bd8d"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5015082477abeafad7203757ae44299a610e89ee82a1503e3d4184e6bafdd519"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:56822ff5ddfd1b745534e658faba944012346184fbfe732e0d6134b744516eea"}, + {file = "aiohttp-3.12.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b2acbbfff69019d9014508c4ba0401822e8bae5a5fdc3b6814285b71231b60f3"}, + {file = "aiohttp-3.12.15-cp312-cp312-win32.whl", hash = "sha256:d849b0901b50f2185874b9a232f38e26b9b3d4810095a7572eacea939132d4e1"}, + {file = "aiohttp-3.12.15-cp312-cp312-win_amd64.whl", hash = "sha256:b390ef5f62bb508a9d67cb3bba9b8356e23b3996da7062f1a57ce1a79d2b3d34"}, + {file = 
"aiohttp-3.12.15-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9f922ffd05034d439dde1c77a20461cf4a1b0831e6caa26151fe7aa8aaebc315"}, + {file = "aiohttp-3.12.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ee8a8ac39ce45f3e55663891d4b1d15598c157b4d494a4613e704c8b43112cd"}, + {file = "aiohttp-3.12.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3eae49032c29d356b94eee45a3f39fdf4b0814b397638c2f718e96cfadf4c4e4"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b97752ff12cc12f46a9b20327104448042fce5c33a624f88c18f66f9368091c7"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:894261472691d6fe76ebb7fcf2e5870a2ac284c7406ddc95823c8598a1390f0d"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5fa5d9eb82ce98959fc1031c28198b431b4d9396894f385cb63f1e2f3f20ca6b"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0fa751efb11a541f57db59c1dd821bec09031e01452b2b6217319b3a1f34f3d"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5346b93e62ab51ee2a9d68e8f73c7cf96ffb73568a23e683f931e52450e4148d"}, + {file = "aiohttp-3.12.15-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:049ec0360f939cd164ecbfd2873eaa432613d5e77d6b04535e3d1fbae5a9e645"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b52dcf013b57464b6d1e51b627adfd69a8053e84b7103a7cd49c030f9ca44461"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9b2af240143dd2765e0fb661fd0361a1b469cab235039ea57663cda087250ea9"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ac77f709a2cde2cc71257ab2d8c74dd157c67a0558a0d2799d5d571b4c63d44d"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:47f6b962246f0a774fbd3b6b7be25d59b06fdb2f164cf2513097998fc6a29693"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:760fb7db442f284996e39cf9915a94492e1896baac44f06ae551974907922b64"}, + {file = "aiohttp-3.12.15-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad702e57dc385cae679c39d318def49aef754455f237499d5b99bea4ef582e51"}, + {file = "aiohttp-3.12.15-cp313-cp313-win32.whl", hash = "sha256:f813c3e9032331024de2eb2e32a88d86afb69291fbc37a3a3ae81cc9917fb3d0"}, + {file = "aiohttp-3.12.15-cp313-cp313-win_amd64.whl", hash = "sha256:1a649001580bdb37c6fdb1bebbd7e3bc688e8ec2b5c6f52edbb664662b17dc84"}, + {file = "aiohttp-3.12.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:691d203c2bdf4f4637792efbbcdcd157ae11e55eaeb5e9c360c1206fb03d4d98"}, + {file = "aiohttp-3.12.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8e995e1abc4ed2a454c731385bf4082be06f875822adc4c6d9eaadf96e20d406"}, + {file = "aiohttp-3.12.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bd44d5936ab3193c617bfd6c9a7d8d1085a8dc8c3f44d5f1dcf554d17d04cf7d"}, + {file = "aiohttp-3.12.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46749be6e89cd78d6068cdf7da51dbcfa4321147ab8e4116ee6678d9a056a0cf"}, + {file = "aiohttp-3.12.15-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0c643f4d75adea39e92c0f01b3fb83d57abdec8c9279b3078b68a3a52b3933b6"}, + {file = 
"aiohttp-3.12.15-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0a23918fedc05806966a2438489dcffccbdf83e921a1170773b6178d04ade142"}, + {file = "aiohttp-3.12.15-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:74bdd8c864b36c3673741023343565d95bfbd778ffe1eb4d412c135a28a8dc89"}, + {file = "aiohttp-3.12.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a146708808c9b7a988a4af3821379e379e0f0e5e466ca31a73dbdd0325b0263"}, + {file = "aiohttp-3.12.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7011a70b56facde58d6d26da4fec3280cc8e2a78c714c96b7a01a87930a9530"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:3bdd6e17e16e1dbd3db74d7f989e8af29c4d2e025f9828e6ef45fbdee158ec75"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:57d16590a351dfc914670bd72530fd78344b885a00b250e992faea565b7fdc05"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:bc9a0f6569ff990e0bbd75506c8d8fe7214c8f6579cca32f0546e54372a3bb54"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:536ad7234747a37e50e7b6794ea868833d5220b49c92806ae2d7e8a9d6b5de02"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:f0adb4177fa748072546fb650d9bd7398caaf0e15b370ed3317280b13f4083b0"}, + {file = "aiohttp-3.12.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14954a2988feae3987f1eb49c706bff39947605f4b6fa4027c1d75743723eb09"}, + {file = "aiohttp-3.12.15-cp39-cp39-win32.whl", hash = "sha256:b784d6ed757f27574dca1c336f968f4e81130b27595e458e69457e6878251f5d"}, + {file = "aiohttp-3.12.15-cp39-cp39-win_amd64.whl", hash = "sha256:86ceded4e78a992f835209e236617bffae649371c4a50d5e5a3987f237db84b8"}, + {file = "aiohttp-3.12.15.tar.gz", hash = "sha256:4fc61385e9c98d72fcdf47e6dd81833f47b2f77c114c29cd64a361be57a763a2"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.5.0" +aiosignal = ">=1.4.0" +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +propcache = ">=0.2.0" +yarl = ">=1.17.0,<2.0" + +[package.extras] +speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.3.0)", "brotlicffi ; platform_python_implementation != \"CPython\""] + +[[package]] +name = "aiosignal" +version = "1.4.0" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e"}, + {file = "aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7"}, +] + +[package.dependencies] +frozenlist = ">=1.1.0" +typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""} + [[package]] name = "alabaster" version = "0.7.16" @@ -52,6 +188,26 @@ xarray-einstats = ">=0.3" all = ["bokeh (>=3)", "contourpy", "dask[distributed]", "dm-tree (>=0.1.8)", "netcdf4", "numba", "ujson", "xarray-datatree", "zarr (>=2.5.0,<3)"] preview = ["arviz-base[h5netcdf]", "arviz-plots", "arviz-stats[xarray]"] +[[package]] +name = "attrs" +version = "25.3.0" +description = "Classes Without Boilerplate" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, + {file = 
"attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, +] + +[package.extras] +benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier"] +tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] + [[package]] name = "babel" version = "2.17.0" @@ -478,6 +634,18 @@ files = [ {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, ] +[[package]] +name = "einops" +version = "0.8.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737"}, + {file = "einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84"}, +] + [[package]] name = "filelock" version = "3.18.0" @@ -578,6 +746,120 @@ ufo = ["fs (>=2.2.0,<3)"] unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] +[[package]] +name = "frozenlist" +version = "1.7.0" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "frozenlist-1.7.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc4df77d638aa2ed703b878dd093725b72a824c3c546c076e8fdf276f78ee84a"}, + {file = "frozenlist-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:716a9973a2cc963160394f701964fe25012600f3d311f60c790400b00e568b61"}, + {file = "frozenlist-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0fd1bad056a3600047fb9462cff4c5322cebc59ebf5d0a3725e0ee78955001d"}, + {file = 
"frozenlist-1.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3789ebc19cb811163e70fe2bd354cea097254ce6e707ae42e56f45e31e96cb8e"}, + {file = "frozenlist-1.7.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af369aa35ee34f132fcfad5be45fbfcde0e3a5f6a1ec0712857f286b7d20cca9"}, + {file = "frozenlist-1.7.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac64b6478722eeb7a3313d494f8342ef3478dff539d17002f849101b212ef97c"}, + {file = "frozenlist-1.7.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f89f65d85774f1797239693cef07ad4c97fdd0639544bad9ac4b869782eb1981"}, + {file = "frozenlist-1.7.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1073557c941395fdfcfac13eb2456cb8aad89f9de27bae29fabca8e563b12615"}, + {file = "frozenlist-1.7.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ed8d2fa095aae4bdc7fdd80351009a48d286635edffee66bf865e37a9125c50"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:24c34bea555fe42d9f928ba0a740c553088500377448febecaa82cc3e88aa1fa"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:69cac419ac6a6baad202c85aaf467b65ac860ac2e7f2ac1686dc40dbb52f6577"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:960d67d0611f4c87da7e2ae2eacf7ea81a5be967861e0c63cf205215afbfac59"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:41be2964bd4b15bf575e5daee5a5ce7ed3115320fb3c2b71fca05582ffa4dc9e"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:46d84d49e00c9429238a7ce02dc0be8f6d7cd0cd405abd1bebdc991bf27c15bd"}, + {file = "frozenlist-1.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:15900082e886edb37480335d9d518cec978afc69ccbc30bd18610b7c1b22a718"}, + {file = "frozenlist-1.7.0-cp310-cp310-win32.whl", hash = "sha256:400ddd24ab4e55014bba442d917203c73b2846391dd42ca5e38ff52bb18c3c5e"}, + {file = "frozenlist-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:6eb93efb8101ef39d32d50bce242c84bcbddb4f7e9febfa7b524532a239b4464"}, + {file = "frozenlist-1.7.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:aa51e147a66b2d74de1e6e2cf5921890de6b0f4820b257465101d7f37b49fb5a"}, + {file = "frozenlist-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9b35db7ce1cd71d36ba24f80f0c9e7cff73a28d7a74e91fe83e23d27c7828750"}, + {file = "frozenlist-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:34a69a85e34ff37791e94542065c8416c1afbf820b68f720452f636d5fb990cd"}, + {file = "frozenlist-1.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a646531fa8d82c87fe4bb2e596f23173caec9185bfbca5d583b4ccfb95183e2"}, + {file = "frozenlist-1.7.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:79b2ffbba483f4ed36a0f236ccb85fbb16e670c9238313709638167670ba235f"}, + {file = "frozenlist-1.7.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a26f205c9ca5829cbf82bb2a84b5c36f7184c4316617d7ef1b271a56720d6b30"}, + {file = "frozenlist-1.7.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bcacfad3185a623fa11ea0e0634aac7b691aa925d50a440f39b458e41c561d98"}, + {file = 
"frozenlist-1.7.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72c1b0fe8fe451b34f12dce46445ddf14bd2a5bcad7e324987194dc8e3a74c86"}, + {file = "frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61d1a5baeaac6c0798ff6edfaeaa00e0e412d49946c53fae8d4b8e8b3566c4ae"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7edf5c043c062462f09b6820de9854bf28cc6cc5b6714b383149745e287181a8"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d50ac7627b3a1bd2dcef6f9da89a772694ec04d9a61b66cf87f7d9446b4a0c31"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ce48b2fece5aeb45265bb7a58259f45027db0abff478e3077e12b05b17fb9da7"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:fe2365ae915a1fafd982c146754e1de6ab3478def8a59c86e1f7242d794f97d5"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:45a6f2fdbd10e074e8814eb98b05292f27bad7d1883afbe009d96abdcf3bc898"}, + {file = "frozenlist-1.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:21884e23cffabb157a9dd7e353779077bf5b8f9a58e9b262c6caad2ef5f80a56"}, + {file = "frozenlist-1.7.0-cp311-cp311-win32.whl", hash = "sha256:284d233a8953d7b24f9159b8a3496fc1ddc00f4db99c324bd5fb5f22d8698ea7"}, + {file = "frozenlist-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:387cbfdcde2f2353f19c2f66bbb52406d06ed77519ac7ee21be0232147c2592d"}, + {file = "frozenlist-1.7.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3dbf9952c4bb0e90e98aec1bd992b3318685005702656bc6f67c1a32b76787f2"}, + {file = "frozenlist-1.7.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1f5906d3359300b8a9bb194239491122e6cf1444c2efb88865426f170c262cdb"}, + {file = "frozenlist-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3dabd5a8f84573c8d10d8859a50ea2dec01eea372031929871368c09fa103478"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa57daa5917f1738064f302bf2626281a1cb01920c32f711fbc7bc36111058a8"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c193dda2b6d49f4c4398962810fa7d7c78f032bf45572b3e04dd5249dff27e08"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe2b675cf0aaa6d61bf8fbffd3c274b3c9b7b1623beb3809df8a81399a4a9c4"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8fc5d5cda37f62b262405cf9652cf0856839c4be8ee41be0afe8858f17f4c94b"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0d5ce521d1dd7d620198829b87ea002956e4319002ef0bc8d3e6d045cb4646e"}, + {file = "frozenlist-1.7.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:488d0a7d6a0008ca0db273c542098a0fa9e7dfaa7e57f70acef43f32b3f69dca"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:15a7eaba63983d22c54d255b854e8108e7e5f3e89f647fc854bd77a237e767df"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:1eaa7e9c6d15df825bf255649e05bd8a74b04a4d2baa1ae46d9c2d00b2ca2cb5"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_i686.whl", hash = 
"sha256:e4389e06714cfa9d47ab87f784a7c5be91d3934cd6e9a7b85beef808297cc025"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:73bd45e1488c40b63fe5a7df892baf9e2a4d4bb6409a2b3b78ac1c6236178e01"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99886d98e1643269760e5fe0df31e5ae7050788dd288947f7f007209b8c33f08"}, + {file = "frozenlist-1.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:290a172aae5a4c278c6da8a96222e6337744cd9c77313efe33d5670b9f65fc43"}, + {file = "frozenlist-1.7.0-cp312-cp312-win32.whl", hash = "sha256:426c7bc70e07cfebc178bc4c2bf2d861d720c4fff172181eeb4a4c41d4ca2ad3"}, + {file = "frozenlist-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:563b72efe5da92e02eb68c59cb37205457c977aa7a449ed1b37e6939e5c47c6a"}, + {file = "frozenlist-1.7.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee80eeda5e2a4e660651370ebffd1286542b67e268aa1ac8d6dbe973120ef7ee"}, + {file = "frozenlist-1.7.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d1a81c85417b914139e3a9b995d4a1c84559afc839a93cf2cb7f15e6e5f6ed2d"}, + {file = "frozenlist-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cbb65198a9132ebc334f237d7b0df163e4de83fb4f2bdfe46c1e654bdb0c5d43"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dab46c723eeb2c255a64f9dc05b8dd601fde66d6b19cdb82b2e09cc6ff8d8b5d"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6aeac207a759d0dedd2e40745575ae32ab30926ff4fa49b1635def65806fddee"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bd8c4e58ad14b4fa7802b8be49d47993182fdd4023393899632c88fd8cd994eb"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04fb24d104f425da3540ed83cbfc31388a586a7696142004c577fa61c6298c3f"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a5c505156368e4ea6b53b5ac23c92d7edc864537ff911d2fb24c140bb175e60"}, + {file = "frozenlist-1.7.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bd7eb96a675f18aa5c553eb7ddc24a43c8c18f22e1f9925528128c052cdbe00"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:05579bf020096fe05a764f1f84cd104a12f78eaab68842d036772dc6d4870b4b"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:376b6222d114e97eeec13d46c486facd41d4f43bab626b7c3f6a8b4e81a5192c"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0aa7e176ebe115379b5b1c95b4096fb1c17cce0847402e227e712c27bdb5a949"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3fbba20e662b9c2130dc771e332a99eff5da078b2b2648153a40669a6d0e36ca"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:f3f4410a0a601d349dd406b5713fec59b4cee7e71678d5b17edda7f4655a940b"}, + {file = "frozenlist-1.7.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e2cdfaaec6a2f9327bf43c933c0319a7c429058e8537c508964a133dffee412e"}, + {file = "frozenlist-1.7.0-cp313-cp313-win32.whl", hash = "sha256:5fc4df05a6591c7768459caba1b342d9ec23fa16195e744939ba5914596ae3e1"}, + {file = "frozenlist-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:52109052b9791a3e6b5d1b65f4b909703984b770694d3eb64fad124c835d7cba"}, + 
{file = "frozenlist-1.7.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a6f86e4193bb0e235ef6ce3dde5cbabed887e0b11f516ce8a0f4d3b33078ec2d"}, + {file = "frozenlist-1.7.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:82d664628865abeb32d90ae497fb93df398a69bb3434463d172b80fc25b0dd7d"}, + {file = "frozenlist-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:912a7e8375a1c9a68325a902f3953191b7b292aa3c3fb0d71a216221deca460b"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9537c2777167488d539bc5de2ad262efc44388230e5118868e172dd4a552b146"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:f34560fb1b4c3e30ba35fa9a13894ba39e5acfc5f60f57d8accde65f46cc5e74"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:acd03d224b0175f5a850edc104ac19040d35419eddad04e7cf2d5986d98427f1"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2038310bc582f3d6a09b3816ab01737d60bf7b1ec70f5356b09e84fb7408ab1"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8c05e4c8e5f36e5e088caa1bf78a687528f83c043706640a92cb76cd6999384"}, + {file = "frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:765bb588c86e47d0b68f23c1bee323d4b703218037765dcf3f25c838c6fecceb"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:32dc2e08c67d86d0969714dd484fd60ff08ff81d1a1e40a77dd34a387e6ebc0c"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:c0303e597eb5a5321b4de9c68e9845ac8f290d2ab3f3e2c864437d3c5a30cd65"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:a47f2abb4e29b3a8d0b530f7c3598badc6b134562b1a5caee867f7c62fee51e3"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:3d688126c242a6fabbd92e02633414d40f50bb6002fa4cf995a1d18051525657"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:4e7e9652b3d367c7bd449a727dc79d5043f48b88d0cbfd4f9f1060cf2b414104"}, + {file = "frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1a85e345b4c43db8b842cab1feb41be5cc0b10a1830e6295b69d7310f99becaf"}, + {file = "frozenlist-1.7.0-cp313-cp313t-win32.whl", hash = "sha256:3a14027124ddb70dfcee5148979998066897e79f89f64b13328595c4bdf77c81"}, + {file = "frozenlist-1.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:3bf8010d71d4507775f658e9823210b7427be36625b387221642725b515dcf3e"}, + {file = "frozenlist-1.7.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cea3dbd15aea1341ea2de490574a4a37ca080b2ae24e4b4f4b51b9057b4c3630"}, + {file = "frozenlist-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7d536ee086b23fecc36c2073c371572374ff50ef4db515e4e503925361c24f71"}, + {file = "frozenlist-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dfcebf56f703cb2e346315431699f00db126d158455e513bd14089d992101e44"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:974c5336e61d6e7eb1ea5b929cb645e882aadab0095c5a6974a111e6479f8878"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = 
"sha256:c70db4a0ab5ab20878432c40563573229a7ed9241506181bba12f6b7d0dc41cb"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1137b78384eebaf70560a36b7b229f752fb64d463d38d1304939984d5cb887b6"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e793a9f01b3e8b5c0bc646fb59140ce0efcc580d22a3468d70766091beb81b35"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74739ba8e4e38221d2c5c03d90a7e542cb8ad681915f4ca8f68d04f810ee0a87"}, + {file = "frozenlist-1.7.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e63344c4e929b1a01e29bc184bbb5fd82954869033765bfe8d65d09e336a677"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2ea2a7369eb76de2217a842f22087913cdf75f63cf1307b9024ab82dfb525938"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:836b42f472a0e006e02499cef9352ce8097f33df43baaba3e0a28a964c26c7d2"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e22b9a99741294b2571667c07d9f8cceec07cb92aae5ccda39ea1b6052ed4319"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:9a19e85cc503d958abe5218953df722748d87172f71b73cf3c9257a91b999890"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:f22dac33bb3ee8fe3e013aa7b91dc12f60d61d05b7fe32191ffa84c3aafe77bd"}, + {file = "frozenlist-1.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ccec739a99e4ccf664ea0775149f2749b8a6418eb5b8384b4dc0a7d15d304cb"}, + {file = "frozenlist-1.7.0-cp39-cp39-win32.whl", hash = "sha256:b3950f11058310008a87757f3eee16a8e1ca97979833239439586857bc25482e"}, + {file = "frozenlist-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:43a82fce6769c70f2f5a06248b614a7d268080a9d20f7457ef10ecee5af82b63"}, + {file = "frozenlist-1.7.0-py3-none-any.whl", hash = "sha256:9a5af342e34f7e97caf8c995864c7a396418ae2859cc6fdf1b1073020d516a7e"}, + {file = "frozenlist-1.7.0.tar.gz", hash = "sha256:2e310d81923c2437ea8670467121cc3e9b0f76d3043cc1d2331d56c7fb7a3a8f"}, +] + [[package]] name = "fsspec" version = "2025.5.1" @@ -590,6 +872,9 @@ files = [ {file = "fsspec-2025.5.1.tar.gz", hash = "sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475"}, ] +[package.dependencies] +aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} + [package.extras] abfs = ["adlfs"] adl = ["adlfs"] @@ -785,7 +1070,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -950,6 +1235,71 @@ files = [ {file = "latexcodec-3.0.1.tar.gz", hash = "sha256:e78a6911cd72f9dec35031c6ec23584de6842bfbc4610a9678868d14cdfb0357"}, ] +[[package]] +name = "lightning" +version = "2.5.4" +description = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "lightning-2.5.4-py3-none-any.whl", hash = "sha256:49fdc84df9d809e4f396ab38a5d327df725fe6dd9303473c49e52cba4fbdcc2b"}, + {file = "lightning-2.5.4.tar.gz", hash = "sha256:cec9459a356117f11c501b591fe80f327947614fc345dc6b6c9f8d4d373f214e"}, +] + +[package.dependencies] +fsspec = {version = ">=2022.5.0,<2027.0", extras = ["http"]} +lightning-utilities = ">=0.10.0,<2.0" +packaging = ">=20.0,<27.0" +pytorch-lightning = "*" +PyYAML = ">5.4,<8.0" +torch = ">=2.1.0,<4.0" +torchmetrics = ">0.7.0,<3.0" +tqdm = ">=4.57.0,<6.0" +typing-extensions = ">4.5.0,<6.0" + +[package.extras] +all = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "hydra-core (>=1.2.0,<2.0)", "ipython[all] (<9.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.2.3,<3.0)", "requests (<3.0)", "rich (>=12.3.0,<15.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +data = ["litdata (>=0.2.0rc,<1.0)"] +dev = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "click (==8.1.8) ; python_version < \"3.11\"", "click (==8.2.1) ; python_version > \"3.10\"", "cloudpickle (>=1.3,<4.0)", "coverage (==7.10.5)", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "fastapi", "hydra-core (>=1.2.0,<2.0)", "ipython[all] (<9.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "numpy (>1.20.0,<2.0)", "numpy (>=1.21.0,<2.0)", "omegaconf (>=2.2.3,<3.0)", "onnx (>1.12.0,<2.0)", "onnxruntime (>=1.12.0,<2.0)", "onnxscript (>=0.1.0,<1.0)", "pandas (>2.0,<3.0)", "psutil (<8.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "requests (<3.0)", "rich (>=12.3.0,<15.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.11,<3.0)", "tensorboardX (>=2.2,<3.0)", "tensorboardX (>=2.6,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)", "uvicorn"] +examples = ["ipython[all] (<9.0)", "requests (<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +extra = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "hydra-core (>=1.2.0,<2.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.2.3,<3.0)", "rich (>=12.3.0,<15.0)", "tensorboardX (>=2.2,<3.0)"] +fabric-all = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +fabric-dev = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "click (==8.1.8) ; python_version < \"3.11\"", "click (==8.2.1) ; python_version > \"3.10\"", "coverage (==7.10.5)", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "numpy (>=1.21.0,<2.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "tensorboardX (>=2.6,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +fabric-examples = ["torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +fabric-strategies = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "deepspeed 
(>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\""] +fabric-test = ["click (==8.1.8) ; python_version < \"3.11\"", "click (==8.2.1) ; python_version > \"3.10\"", "coverage (==7.10.5)", "numpy (>=1.21.0,<2.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "tensorboardX (>=2.6,<3.0)"] +pytorch-all = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "hydra-core (>=1.2.0,<2.0)", "ipython[all] (<9.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.2.3,<3.0)", "requests (<3.0)", "rich (>=12.3.0,<15.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +pytorch-dev = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "cloudpickle (>=1.3,<4.0)", "coverage (==7.10.5)", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "fastapi", "hydra-core (>=1.2.0,<2.0)", "ipython[all] (<9.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "numpy (>1.20.0,<2.0)", "omegaconf (>=2.2.3,<3.0)", "onnx (>1.12.0,<2.0)", "onnxruntime (>=1.12.0,<2.0)", "onnxscript (>=0.1.0,<1.0)", "pandas (>2.0,<3.0)", "psutil (<8.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "requests (<3.0)", "rich (>=12.3.0,<15.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.11,<3.0)", "tensorboardX (>=2.2,<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)", "uvicorn"] +pytorch-examples = ["ipython[all] (<9.0)", "requests (<3.0)", "torchmetrics (>=0.10.0,<2.0)", "torchvision (>=0.16.0,<1.0)"] +pytorch-extra = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "hydra-core (>=1.2.0,<2.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0,<5.0)", "matplotlib (>3.1,<4.0)", "omegaconf (>=2.2.3,<3.0)", "rich (>=12.3.0,<15.0)", "tensorboardX (>=2.2,<3.0)"] +pytorch-strategies = ["deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\""] +pytorch-test = ["cloudpickle (>=1.3,<4.0)", "coverage (==7.10.5)", "fastapi", "numpy (>1.20.0,<2.0)", "onnx (>1.12.0,<2.0)", "onnxruntime (>=1.12.0,<2.0)", "onnxscript (>=0.1.0,<1.0)", "pandas (>2.0,<3.0)", "psutil (<8.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.11,<3.0)", "uvicorn"] +strategies = ["bitsandbytes (>=0.45.2,<1.0) ; platform_system != \"Darwin\"", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\""] +test = ["click (==8.1.8) ; python_version < \"3.11\"", "click (==8.2.1) ; python_version > \"3.10\"", "cloudpickle (>=1.3,<4.0)", "coverage (==7.10.5)", "fastapi", "numpy (>1.20.0,<2.0)", "numpy (>=1.21.0,<2.0)", "onnx (>1.12.0,<2.0)", "onnxruntime (>=1.12.0,<2.0)", "onnxscript (>=0.1.0,<1.0)", "pandas (>2.0,<3.0)", "psutil (<8.0)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "scikit-learn (>0.22.1,<2.0)", "tensorboard (>=2.11,<3.0)", "tensorboardX (>=2.6,<3.0)", "uvicorn"] + +[[package]] +name = "lightning-utilities" +version = "0.15.2" 
+description = "Lightning toolbox for across the our ecosystem." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "lightning_utilities-0.15.2-py3-none-any.whl", hash = "sha256:ad3ab1703775044bbf880dbf7ddaaac899396c96315f3aa1779cec9d618a9841"}, + {file = "lightning_utilities-0.15.2.tar.gz", hash = "sha256:cdf12f530214a63dacefd713f180d1ecf5d165338101617b4742e8f22c032e24"}, +] + +[package.dependencies] +packaging = ">=17.1" +setuptools = "*" +typing_extensions = "*" + +[package.extras] +cli = ["jsonargparse[signatures] (>=4.38.0)", "tomlkit"] +docs = ["requests (>=2.0.0)"] +typing = ["mypy (>=1.0.0)", "types-setuptools"] + [[package]] name = "markdown" version = "3.8" @@ -1125,6 +1475,126 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] +[[package]] +name = "multidict" +version = "6.6.4" +description = "multidict implementation" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "multidict-6.6.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b8aa6f0bd8125ddd04a6593437bad6a7e70f300ff4180a531654aa2ab3f6d58f"}, + {file = "multidict-6.6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9e5853bbd7264baca42ffc53391b490d65fe62849bf2c690fa3f6273dbcd0cb"}, + {file = "multidict-6.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0af5f9dee472371e36d6ae38bde009bd8ce65ac7335f55dcc240379d7bed1495"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:d24f351e4d759f5054b641c81e8291e5d122af0fca5c72454ff77f7cbe492de8"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db6a3810eec08280a172a6cd541ff4a5f6a97b161d93ec94e6c4018917deb6b7"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a1b20a9d56b2d81e2ff52ecc0670d583eaabaa55f402e8d16dd062373dbbe796"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8c9854df0eaa610a23494c32a6f44a3a550fb398b6b51a56e8c6b9b3689578db"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4bb7627fd7a968f41905a4d6343b0d63244a0623f006e9ed989fa2b78f4438a0"}, + {file = "multidict-6.6.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:caebafea30ed049c57c673d0b36238b1748683be2593965614d7b0e99125c877"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ad887a8250eb47d3ab083d2f98db7f48098d13d42eb7a3b67d8a5c795f224ace"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:ed8358ae7d94ffb7c397cecb62cbac9578a83ecefc1eba27b9090ee910e2efb6"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ecab51ad2462197a4c000b6d5701fc8585b80eecb90583635d7e327b7b6923eb"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c5c97aa666cf70e667dfa5af945424ba1329af5dd988a437efeb3a09430389fb"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:9a950b7cf54099c1209f455ac5970b1ea81410f2af60ed9eb3c3f14f0bfcf987"}, + {file = "multidict-6.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:163c7ea522ea9365a8a57832dea7618e6cbdc3cd75f8c627663587459a4e328f"}, + {file = 
"multidict-6.6.4-cp310-cp310-win32.whl", hash = "sha256:17d2cbbfa6ff20821396b25890f155f40c986f9cfbce5667759696d83504954f"}, + {file = "multidict-6.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:ce9a40fbe52e57e7edf20113a4eaddfacac0561a0879734e636aa6d4bb5e3fb0"}, + {file = "multidict-6.6.4-cp310-cp310-win_arm64.whl", hash = "sha256:01d0959807a451fe9fdd4da3e139cb5b77f7328baf2140feeaf233e1d777b729"}, + {file = "multidict-6.6.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c7a0e9b561e6460484318a7612e725df1145d46b0ef57c6b9866441bf6e27e0c"}, + {file = "multidict-6.6.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6bf2f10f70acc7a2446965ffbc726e5fc0b272c97a90b485857e5c70022213eb"}, + {file = "multidict-6.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66247d72ed62d5dd29752ffc1d3b88f135c6a8de8b5f63b7c14e973ef5bda19e"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:105245cc6b76f51e408451a844a54e6823bbd5a490ebfe5bdfc79798511ceded"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cbbc54e58b34c3bae389ef00046be0961f30fef7cb0dd9c7756aee376a4f7683"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:56c6b3652f945c9bc3ac6c8178cd93132b8d82dd581fcbc3a00676c51302bc1a"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b95494daf857602eccf4c18ca33337dd2be705bccdb6dddbfc9d513e6addb9d9"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e5b1413361cef15340ab9dc61523e653d25723e82d488ef7d60a12878227ed50"}, + {file = "multidict-6.6.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e167bf899c3d724f9662ef00b4f7fef87a19c22b2fead198a6f68b263618df52"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aaea28ba20a9026dfa77f4b80369e51cb767c61e33a2d4043399c67bd95fb7c6"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8c91cdb30809a96d9ecf442ec9bc45e8cfaa0f7f8bdf534e082c2443a196727e"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a0ccbfe93ca114c5d65a2471d52d8829e56d467c97b0e341cf5ee45410033b3"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:55624b3f321d84c403cb7d8e6e982f41ae233d85f85db54ba6286f7295dc8a9c"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:4a1fb393a2c9d202cb766c76208bd7945bc194eba8ac920ce98c6e458f0b524b"}, + {file = "multidict-6.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43868297a5759a845fa3a483fb4392973a95fb1de891605a3728130c52b8f40f"}, + {file = "multidict-6.6.4-cp311-cp311-win32.whl", hash = "sha256:ed3b94c5e362a8a84d69642dbeac615452e8af9b8eb825b7bc9f31a53a1051e2"}, + {file = "multidict-6.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:d8c112f7a90d8ca5d20213aa41eac690bb50a76da153e3afb3886418e61cb22e"}, + {file = "multidict-6.6.4-cp311-cp311-win_arm64.whl", hash = "sha256:3bb0eae408fa1996d87247ca0d6a57b7fc1dcf83e8a5c47ab82c558c250d4adf"}, + {file = "multidict-6.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0ffb87be160942d56d7b87b0fdf098e81ed565add09eaa1294268c7f3caac4c8"}, + {file = "multidict-6.6.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:d191de6cbab2aff5de6c5723101705fd044b3e4c7cfd587a1929b5028b9714b3"}, + {file = "multidict-6.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38a0956dd92d918ad5feff3db8fcb4a5eb7dba114da917e1a88475619781b57b"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:6865f6d3b7900ae020b495d599fcf3765653bc927951c1abb959017f81ae8287"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a2088c126b6f72db6c9212ad827d0ba088c01d951cee25e758c450da732c138"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0f37bed7319b848097085d7d48116f545985db988e2256b2e6f00563a3416ee6"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:01368e3c94032ba6ca0b78e7ccb099643466cf24f8dc8eefcfdc0571d56e58f9"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8fe323540c255db0bffee79ad7f048c909f2ab0edb87a597e1c17da6a54e493c"}, + {file = "multidict-6.6.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8eb3025f17b0a4c3cd08cda49acf312a19ad6e8a4edd9dbd591e6506d999402"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bbc14f0365534d35a06970d6a83478b249752e922d662dc24d489af1aa0d1be7"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:75aa52fba2d96bf972e85451b99d8e19cc37ce26fd016f6d4aa60da9ab2b005f"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fefd4a815e362d4f011919d97d7b4a1e566f1dde83dc4ad8cfb5b41de1df68d"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:db9801fe021f59a5b375ab778973127ca0ac52429a26e2fd86aa9508f4d26eb7"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a650629970fa21ac1fb06ba25dabfc5b8a2054fcbf6ae97c758aa956b8dba802"}, + {file = "multidict-6.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:452ff5da78d4720d7516a3a2abd804957532dd69296cb77319c193e3ffb87e24"}, + {file = "multidict-6.6.4-cp312-cp312-win32.whl", hash = "sha256:8c2fcb12136530ed19572bbba61b407f655e3953ba669b96a35036a11a485793"}, + {file = "multidict-6.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:047d9425860a8c9544fed1b9584f0c8bcd31bcde9568b047c5e567a1025ecd6e"}, + {file = "multidict-6.6.4-cp312-cp312-win_arm64.whl", hash = "sha256:14754eb72feaa1e8ae528468f24250dd997b8e2188c3d2f593f9eba259e4b364"}, + {file = "multidict-6.6.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f46a6e8597f9bd71b31cc708195d42b634c8527fecbcf93febf1052cacc1f16e"}, + {file = "multidict-6.6.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:22e38b2bc176c5eb9c0a0e379f9d188ae4cd8b28c0f53b52bce7ab0a9e534657"}, + {file = "multidict-6.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5df8afd26f162da59e218ac0eefaa01b01b2e6cd606cffa46608f699539246da"}, + {file = "multidict-6.6.4-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:49517449b58d043023720aa58e62b2f74ce9b28f740a0b5d33971149553d72aa"}, + {file = "multidict-6.6.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae9408439537c5afdca05edd128a63f56a62680f4b3c234301055d7a2000220f"}, + {file = 
"multidict-6.6.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:87a32d20759dc52a9e850fe1061b6e41ab28e2998d44168a8a341b99ded1dba0"}, + {file = "multidict-6.6.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:52e3c8d43cdfff587ceedce9deb25e6ae77daba560b626e97a56ddcad3756879"}, + {file = "multidict-6.6.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ad8850921d3a8d8ff6fbef790e773cecfc260bbfa0566998980d3fa8f520bc4a"}, + {file = "multidict-6.6.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:497a2954adc25c08daff36f795077f63ad33e13f19bfff7736e72c785391534f"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:024ce601f92d780ca1617ad4be5ac15b501cc2414970ffa2bb2bbc2bd5a68fa5"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:a693fc5ed9bdd1c9e898013e0da4dcc640de7963a371c0bd458e50e046bf6438"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:190766dac95aab54cae5b152a56520fd99298f32a1266d66d27fdd1b5ac00f4e"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:34d8f2a5ffdceab9dcd97c7a016deb2308531d5f0fced2bb0c9e1df45b3363d7"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:59e8d40ab1f5a8597abcef00d04845155a5693b5da00d2c93dbe88f2050f2812"}, + {file = "multidict-6.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:467fe64138cfac771f0e949b938c2e1ada2b5af22f39692aa9258715e9ea613a"}, + {file = "multidict-6.6.4-cp313-cp313-win32.whl", hash = "sha256:14616a30fe6d0a48d0a48d1a633ab3b8bec4cf293aac65f32ed116f620adfd69"}, + {file = "multidict-6.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:40cd05eaeb39e2bc8939451f033e57feaa2ac99e07dbca8afe2be450a4a3b6cf"}, + {file = "multidict-6.6.4-cp313-cp313-win_arm64.whl", hash = "sha256:f6eb37d511bfae9e13e82cb4d1af36b91150466f24d9b2b8a9785816deb16605"}, + {file = "multidict-6.6.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:6c84378acd4f37d1b507dfa0d459b449e2321b3ba5f2338f9b085cf7a7ba95eb"}, + {file = "multidict-6.6.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0e0558693063c75f3d952abf645c78f3c5dfdd825a41d8c4d8156fc0b0da6e7e"}, + {file = "multidict-6.6.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3f8e2384cb83ebd23fd07e9eada8ba64afc4c759cd94817433ab8c81ee4b403f"}, + {file = "multidict-6.6.4-cp313-cp313t-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:f996b87b420995a9174b2a7c1a8daf7db4750be6848b03eb5e639674f7963773"}, + {file = "multidict-6.6.4-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc356250cffd6e78416cf5b40dc6a74f1edf3be8e834cf8862d9ed5265cf9b0e"}, + {file = "multidict-6.6.4-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:dadf95aa862714ea468a49ad1e09fe00fcc9ec67d122f6596a8d40caf6cec7d0"}, + {file = "multidict-6.6.4-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7dd57515bebffd8ebd714d101d4c434063322e4fe24042e90ced41f18b6d3395"}, + {file = "multidict-6.6.4-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:967af5f238ebc2eb1da4e77af5492219fbd9b4b812347da39a7b5f5c72c0fa45"}, + {file = 
"multidict-6.6.4-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a4c6875c37aae9794308ec43e3530e4aa0d36579ce38d89979bbf89582002bb"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:7f683a551e92bdb7fac545b9c6f9fa2aebdeefa61d607510b3533286fcab67f5"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:3ba5aaf600edaf2a868a391779f7a85d93bed147854925f34edd24cc70a3e141"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:580b643b7fd2c295d83cad90d78419081f53fd532d1f1eb67ceb7060f61cff0d"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:37b7187197da6af3ee0b044dbc9625afd0c885f2800815b228a0e70f9a7f473d"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e1b93790ed0bc26feb72e2f08299691ceb6da5e9e14a0d13cc74f1869af327a0"}, + {file = "multidict-6.6.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:a506a77ddee1efcca81ecbeae27ade3e09cdf21a8ae854d766c2bb4f14053f92"}, + {file = "multidict-6.6.4-cp313-cp313t-win32.whl", hash = "sha256:f93b2b2279883d1d0a9e1bd01f312d6fc315c5e4c1f09e112e4736e2f650bc4e"}, + {file = "multidict-6.6.4-cp313-cp313t-win_amd64.whl", hash = "sha256:6d46a180acdf6e87cc41dc15d8f5c2986e1e8739dc25dbb7dac826731ef381a4"}, + {file = "multidict-6.6.4-cp313-cp313t-win_arm64.whl", hash = "sha256:756989334015e3335d087a27331659820d53ba432befdef6a718398b0a8493ad"}, + {file = "multidict-6.6.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:af7618b591bae552b40dbb6f93f5518328a949dac626ee75927bba1ecdeea9f4"}, + {file = "multidict-6.6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b6819f83aef06f560cb15482d619d0e623ce9bf155115150a85ab11b8342a665"}, + {file = "multidict-6.6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4d09384e75788861e046330308e7af54dd306aaf20eb760eb1d0de26b2bea2cb"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:a59c63061f1a07b861c004e53869eb1211ffd1a4acbca330e3322efa6dd02978"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:350f6b0fe1ced61e778037fdc7613f4051c8baf64b1ee19371b42a3acdb016a0"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0c5cbac6b55ad69cb6aa17ee9343dfbba903118fd530348c330211dc7aa756d1"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:630f70c32b8066ddfd920350bc236225814ad94dfa493fe1910ee17fe4365cbb"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f8d4916a81697faec6cb724a273bd5457e4c6c43d82b29f9dc02c5542fd21fc9"}, + {file = "multidict-6.6.4-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e42332cf8276bb7645d310cdecca93a16920256a5b01bebf747365f86a1675b"}, + {file = "multidict-6.6.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f3be27440f7644ab9a13a6fc86f09cdd90b347c3c5e30c6d6d860de822d7cb53"}, + {file = "multidict-6.6.4-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:21f216669109e02ef3e2415ede07f4f8987f00de8cdfa0cc0b3440d42534f9f0"}, + {file = "multidict-6.6.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:d9890d68c45d1aeac5178ded1d1cccf3bc8d7accf1f976f79bf63099fb16e4bd"}, + {file = 
"multidict-6.6.4-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:edfdcae97cdc5d1a89477c436b61f472c4d40971774ac4729c613b4b133163cb"}, + {file = "multidict-6.6.4-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0b2e886624be5773e69cf32bcb8534aecdeb38943520b240fed3d5596a430f2f"}, + {file = "multidict-6.6.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:be5bf4b3224948032a845d12ab0f69f208293742df96dc14c4ff9b09e508fc17"}, + {file = "multidict-6.6.4-cp39-cp39-win32.whl", hash = "sha256:10a68a9191f284fe9d501fef4efe93226e74df92ce7a24e301371293bd4918ae"}, + {file = "multidict-6.6.4-cp39-cp39-win_amd64.whl", hash = "sha256:ee25f82f53262f9ac93bd7e58e47ea1bdcc3393cef815847e397cba17e284210"}, + {file = "multidict-6.6.4-cp39-cp39-win_arm64.whl", hash = "sha256:f9867e55590e0855bcec60d4f9a092b69476db64573c9fe17e92b0c50614c16a"}, + {file = "multidict-6.6.4-py3-none-any.whl", hash = "sha256:27d8f8e125c07cb954e54d75d04905a9bba8a439c1d84aca94949d4d03d8601c"}, + {file = "multidict-6.6.4.tar.gz", hash = "sha256:d2d4e4787672911b48350df02ed3fa3fffdc2f2e8ca06dd6afdf34189b76a9dd"}, +] + [[package]] name = "mypy-extensions" version = "1.1.0" @@ -1742,6 +2212,114 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "propcache" +version = "0.3.2" +description = "Accelerated property cache" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "propcache-0.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:22d9962a358aedbb7a2e36187ff273adeaab9743373a272976d2e348d08c7770"}, + {file = "propcache-0.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d0fda578d1dc3f77b6b5a5dce3b9ad69a8250a891760a548df850a5e8da87f3"}, + {file = "propcache-0.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3def3da3ac3ce41562d85db655d18ebac740cb3fa4367f11a52b3da9d03a5cc3"}, + {file = "propcache-0.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bec58347a5a6cebf239daba9bda37dffec5b8d2ce004d9fe4edef3d2815137e"}, + {file = "propcache-0.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55ffda449a507e9fbd4aca1a7d9aa6753b07d6166140e5a18d2ac9bc49eac220"}, + {file = "propcache-0.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64a67fb39229a8a8491dd42f864e5e263155e729c2e7ff723d6e25f596b1e8cb"}, + {file = "propcache-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9da1cf97b92b51253d5b68cf5a2b9e0dafca095e36b7f2da335e27dc6172a614"}, + {file = "propcache-0.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f559e127134b07425134b4065be45b166183fdcb433cb6c24c8e4149056ad50"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aff2e4e06435d61f11a428360a932138d0ec288b0a31dd9bd78d200bd4a2b339"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:4927842833830942a5d0a56e6f4839bc484785b8e1ce8d287359794818633ba0"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6107ddd08b02654a30fb8ad7a132021759d750a82578b94cd55ee2772b6ebea2"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:70bd8b9cd6b519e12859c99f3fc9a93f375ebd22a50296c3a295028bea73b9e7"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2183111651d710d3097338dd1893fcf09c9f54e27ff1a8795495a16a469cc90b"}, + {file = "propcache-0.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:fb075ad271405dcad8e2a7ffc9a750a3bf70e533bd86e89f0603e607b93aa64c"}, + {file = "propcache-0.3.2-cp310-cp310-win32.whl", hash = "sha256:404d70768080d3d3bdb41d0771037da19d8340d50b08e104ca0e7f9ce55fce70"}, + {file = "propcache-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:7435d766f978b4ede777002e6b3b6641dd229cd1da8d3d3106a45770365f9ad9"}, + {file = "propcache-0.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0b8d2f607bd8f80ddc04088bc2a037fdd17884a6fcadc47a96e334d72f3717be"}, + {file = "propcache-0.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06766d8f34733416e2e34f46fea488ad5d60726bb9481d3cddf89a6fa2d9603f"}, + {file = "propcache-0.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2dc1f4a1df4fecf4e6f68013575ff4af84ef6f478fe5344317a65d38a8e6dc9"}, + {file = "propcache-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be29c4f4810c5789cf10ddf6af80b041c724e629fa51e308a7a0fb19ed1ef7bf"}, + {file = "propcache-0.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59d61f6970ecbd8ff2e9360304d5c8876a6abd4530cb752c06586849ac8a9dc9"}, + {file = "propcache-0.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:62180e0b8dbb6b004baec00a7983e4cc52f5ada9cd11f48c3528d8cfa7b96a66"}, + {file = "propcache-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c144ca294a204c470f18cf4c9d78887810d04a3e2fbb30eea903575a779159df"}, + {file = "propcache-0.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5c2a784234c28854878d68978265617aa6dc0780e53d44b4d67f3651a17a9a2"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5745bc7acdafa978ca1642891b82c19238eadc78ba2aaa293c6863b304e552d7"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:c0075bf773d66fa8c9d41f66cc132ecc75e5bb9dd7cce3cfd14adc5ca184cb95"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5f57aa0847730daceff0497f417c9de353c575d8da3579162cc74ac294c5369e"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:eef914c014bf72d18efb55619447e0aecd5fb7c2e3fa7441e2e5d6099bddff7e"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:2a4092e8549031e82facf3decdbc0883755d5bbcc62d3aea9d9e185549936dcf"}, + {file = "propcache-0.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:85871b050f174bc0bfb437efbdb68aaf860611953ed12418e4361bc9c392749e"}, + {file = "propcache-0.3.2-cp311-cp311-win32.whl", hash = "sha256:36c8d9b673ec57900c3554264e630d45980fd302458e4ac801802a7fd2ef7897"}, + {file = "propcache-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e53af8cb6a781b02d2ea079b5b853ba9430fcbe18a8e3ce647d5982a3ff69f39"}, + {file = "propcache-0.3.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:8de106b6c84506b31c27168582cd3cb3000a6412c16df14a8628e5871ff83c10"}, + {file = "propcache-0.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:28710b0d3975117239c76600ea351934ac7b5ff56e60953474342608dbbb6154"}, + {file = "propcache-0.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce26862344bdf836650ed2487c3d724b00fbfec4233a1013f597b78c1cb73615"}, + {file = "propcache-0.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bca54bd347a253af2cf4544bbec232ab982f4868de0dd684246b67a51bc6b1db"}, + {file = "propcache-0.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:55780d5e9a2ddc59711d727226bb1ba83a22dd32f64ee15594b9392b1f544eb1"}, + {file = "propcache-0.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:035e631be25d6975ed87ab23153db6a73426a48db688070d925aa27e996fe93c"}, + {file = "propcache-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee6f22b6eaa39297c751d0e80c0d3a454f112f5c6481214fcf4c092074cecd67"}, + {file = "propcache-0.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ca3aee1aa955438c4dba34fc20a9f390e4c79967257d830f137bd5a8a32ed3b"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7a4f30862869fa2b68380d677cc1c5fcf1e0f2b9ea0cf665812895c75d0ca3b8"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b77ec3c257d7816d9f3700013639db7491a434644c906a2578a11daf13176251"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:cab90ac9d3f14b2d5050928483d3d3b8fb6b4018893fc75710e6aa361ecb2474"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0b504d29f3c47cf6b9e936c1852246c83d450e8e063d50562115a6be6d3a2535"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:ce2ac2675a6aa41ddb2a0c9cbff53780a617ac3d43e620f8fd77ba1c84dcfc06"}, + {file = "propcache-0.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b4239611205294cc433845b914131b2a1f03500ff3c1ed093ed216b82621e1"}, + {file = "propcache-0.3.2-cp312-cp312-win32.whl", hash = "sha256:df4a81b9b53449ebc90cc4deefb052c1dd934ba85012aa912c7ea7b7e38b60c1"}, + {file = "propcache-0.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7046e79b989d7fe457bb755844019e10f693752d169076138abf17f31380800c"}, + {file = "propcache-0.3.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ca592ed634a73ca002967458187109265e980422116c0a107cf93d81f95af945"}, + {file = "propcache-0.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9ecb0aad4020e275652ba3975740f241bd12a61f1a784df044cf7477a02bc252"}, + {file = "propcache-0.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7f08f1cc28bd2eade7a8a3d2954ccc673bb02062e3e7da09bc75d843386b342f"}, + {file = "propcache-0.3.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1a342c834734edb4be5ecb1e9fb48cb64b1e2320fccbd8c54bf8da8f2a84c33"}, + {file = "propcache-0.3.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a544caaae1ac73f1fecfae70ded3e93728831affebd017d53449e3ac052ac1e"}, + {file = "propcache-0.3.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310d11aa44635298397db47a3ebce7db99a4cc4b9bbdfcf6c98a60c8d5261cf1"}, + {file = "propcache-0.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c1396592321ac83157ac03a2023aa6cc4a3cc3cfdecb71090054c09e5a7cce3"}, + {file = "propcache-0.3.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cabf5b5902272565e78197edb682017d21cf3b550ba0460ee473753f28d23c1"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0a2f2235ac46a7aa25bdeb03a9e7060f6ecbd213b1f9101c43b3090ffb971ef6"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:92b69e12e34869a6970fd2f3da91669899994b47c98f5d430b781c26f1d9f387"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_i686.whl", hash = 
"sha256:54e02207c79968ebbdffc169591009f4474dde3b4679e16634d34c9363ff56b4"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4adfb44cb588001f68c5466579d3f1157ca07f7504fc91ec87862e2b8e556b88"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:fd3e6019dc1261cd0291ee8919dd91fbab7b169bb76aeef6c716833a3f65d206"}, + {file = "propcache-0.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4c181cad81158d71c41a2bce88edce078458e2dd5ffee7eddd6b05da85079f43"}, + {file = "propcache-0.3.2-cp313-cp313-win32.whl", hash = "sha256:8a08154613f2249519e549de2330cf8e2071c2887309a7b07fb56098f5170a02"}, + {file = "propcache-0.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:e41671f1594fc4ab0a6dec1351864713cb3a279910ae8b58f884a88a0a632c05"}, + {file = "propcache-0.3.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:9a3cf035bbaf035f109987d9d55dc90e4b0e36e04bbbb95af3055ef17194057b"}, + {file = "propcache-0.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:156c03d07dc1323d8dacaa221fbe028c5c70d16709cdd63502778e6c3ccca1b0"}, + {file = "propcache-0.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74413c0ba02ba86f55cf60d18daab219f7e531620c15f1e23d95563f505efe7e"}, + {file = "propcache-0.3.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f066b437bb3fa39c58ff97ab2ca351db465157d68ed0440abecb21715eb24b28"}, + {file = "propcache-0.3.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1304b085c83067914721e7e9d9917d41ad87696bf70f0bc7dee450e9c71ad0a"}, + {file = "propcache-0.3.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ab50cef01b372763a13333b4e54021bdcb291fc9a8e2ccb9c2df98be51bcde6c"}, + {file = "propcache-0.3.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fad3b2a085ec259ad2c2842666b2a0a49dea8463579c606426128925af1ed725"}, + {file = "propcache-0.3.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:261fa020c1c14deafd54c76b014956e2f86991af198c51139faf41c4d5e83892"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:46d7f8aa79c927e5f987ee3a80205c987717d3659f035c85cf0c3680526bdb44"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:6d8f3f0eebf73e3c0ff0e7853f68be638b4043c65a70517bb575eff54edd8dbe"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:03c89c1b14a5452cf15403e291c0ccd7751d5b9736ecb2c5bab977ad6c5bcd81"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:0cc17efde71e12bbaad086d679ce575268d70bc123a5a71ea7ad76f70ba30bba"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:acdf05d00696bc0447e278bb53cb04ca72354e562cf88ea6f9107df8e7fd9770"}, + {file = "propcache-0.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4445542398bd0b5d32df908031cb1b30d43ac848e20470a878b770ec2dcc6330"}, + {file = "propcache-0.3.2-cp313-cp313t-win32.whl", hash = "sha256:f86e5d7cd03afb3a1db8e9f9f6eff15794e79e791350ac48a8c924e6f439f394"}, + {file = "propcache-0.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9704bedf6e7cbe3c65eca4379a9b53ee6a83749f047808cbb5044d40d7d72198"}, + {file = "propcache-0.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a7fad897f14d92086d6b03fdd2eb844777b0c4d7ec5e3bac0fbae2ab0602bbe5"}, + {file = "propcache-0.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:1f43837d4ca000243fd7fd6301947d7cb93360d03cd08369969450cc6b2ce3b4"}, + {file = "propcache-0.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:261df2e9474a5949c46e962065d88eb9b96ce0f2bd30e9d3136bcde84befd8f2"}, + {file = "propcache-0.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e514326b79e51f0a177daab1052bc164d9d9e54133797a3a58d24c9c87a3fe6d"}, + {file = "propcache-0.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d4a996adb6904f85894570301939afeee65f072b4fd265ed7e569e8d9058e4ec"}, + {file = "propcache-0.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76cace5d6b2a54e55b137669b30f31aa15977eeed390c7cbfb1dafa8dfe9a701"}, + {file = "propcache-0.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31248e44b81d59d6addbb182c4720f90b44e1efdc19f58112a3c3a1615fb47ef"}, + {file = "propcache-0.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abb7fa19dbf88d3857363e0493b999b8011eea856b846305d8c0512dfdf8fbb1"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d81ac3ae39d38588ad0549e321e6f773a4e7cc68e7751524a22885d5bbadf886"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:cc2782eb0f7a16462285b6f8394bbbd0e1ee5f928034e941ffc444012224171b"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:db429c19a6c7e8a1c320e6a13c99799450f411b02251fb1b75e6217cf4a14fcb"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:21d8759141a9e00a681d35a1f160892a36fb6caa715ba0b832f7747da48fb6ea"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:2ca6d378f09adb13837614ad2754fa8afaee330254f404299611bce41a8438cb"}, + {file = "propcache-0.3.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:34a624af06c048946709f4278b4176470073deda88d91342665d95f7c6270fbe"}, + {file = "propcache-0.3.2-cp39-cp39-win32.whl", hash = "sha256:4ba3fef1c30f306b1c274ce0b8baaa2c3cdd91f645c48f06394068f37d3837a1"}, + {file = "propcache-0.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:7a2368eed65fc69a7a7a40b27f22e85e7627b74216f0846b04ba5c116e191ec9"}, + {file = "propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f"}, + {file = "propcache-0.3.2.tar.gz", hash = "sha256:20d7d62e4e7ef05f221e0db2856b979540686342e7dd9973b815599c7057e168"}, +] + [[package]] name = "protobuf" version = "6.31.1" @@ -1973,6 +2551,37 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytorch-lightning" +version = "2.5.4" +description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pytorch_lightning-2.5.4-py3-none-any.whl", hash = "sha256:cbdc45c1fbd6dbaf856990c618de994c3bcca7fea599b8471ac9fa59df598b38"}, + {file = "pytorch_lightning-2.5.4.tar.gz", hash = "sha256:159b63f3dcd72da50566dc4b599adb4adcd07503193ade4fa518e51ccd0751ef"}, +] + +[package.dependencies] +fsspec = {version = ">=2022.5.0", extras = ["http"]} +lightning-utilities = ">=0.10.0" +packaging = ">=20.0" +PyYAML = ">5.4" +torch = ">=2.1.0" +torchmetrics = ">0.7.0" +tqdm = ">=4.57.0" +typing-extensions = ">4.5.0" + +[package.extras] +all = ["bitsandbytes (>=0.45.2) ; platform_system != \"Darwin\"", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "hydra-core (>=1.2.0)", "ipython[all] (<8.19.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0)", "matplotlib (>3.1)", "omegaconf (>=2.2.3)", "requests (<2.33.0)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.16.0)"] +deepspeed = ["deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\""] +dev = ["bitsandbytes (>=0.45.2) ; platform_system != \"Darwin\"", "cloudpickle (>=1.3)", "coverage (==7.10.5)", "deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\"", "fastapi", "hydra-core (>=1.2.0)", "ipython[all] (<8.19.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0)", "matplotlib (>3.1)", "numpy (>1.20.0)", "omegaconf (>=2.2.3)", "onnx (>1.12.0)", "onnxruntime (>=1.12.0)", "onnxscript (>=0.1.0)", "pandas (>2.0)", "psutil (<7.0.1)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "requests (<2.33.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.11)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.16.0)", "uvicorn"] +examples = ["ipython[all] (<8.19.0)", "requests (<2.33.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.16.0)"] +extra = ["bitsandbytes (>=0.45.2) ; platform_system != \"Darwin\"", "hydra-core (>=1.2.0)", "jsonargparse[jsonnet,signatures] (>=4.39.0)", "matplotlib (>3.1)", "omegaconf (>=2.2.3)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] +strategies = ["deepspeed (>=0.14.1,<=0.15.0) ; platform_system != \"Windows\" and platform_system != \"Darwin\""] +test = ["cloudpickle (>=1.3)", "coverage (==7.10.5)", "fastapi", "numpy (>1.20.0)", "onnx (>1.12.0)", "onnxruntime (>=1.12.0)", "onnxscript (>=0.1.0)", "pandas (>2.0)", "psutil (<7.0.1)", "pytest (==8.4.1)", "pytest-cov (==6.2.1)", "pytest-random-order (==1.2.0)", "pytest-rerunfailures (==15.1)", "pytest-timeout (==2.4.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.11)", "uvicorn"] + [[package]] name = "pytz" version = "2025.2" @@ -2608,6 +3217,37 @@ typing-extensions = ">=4.10.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.13.0)"] +[[package]] +name = "torchmetrics" +version = "1.8.1" +description = "PyTorch native Metrics" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "torchmetrics-1.8.1-py3-none-any.whl", hash = "sha256:2437501351e0da3d294c71210ce8139b9c762b5e20604f7a051a725443db8f4b"}, + {file = "torchmetrics-1.8.1.tar.gz", hash = "sha256:04ca021105871637c5d34d0a286b3ab665a1e3d2b395e561f14188a96e862fdb"}, +] + +[package.dependencies] +lightning-utilities = ">=0.8.0" +numpy = ">1.20.0" +packaging = ">17.1" +torch = ">=2.0.0" + +[package.extras] +all = ["SciencePlots 
(>=2.0.0)", "einops (>=0.7.0)", "einops (>=0.7.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.17.1)", "nltk (>3.8.1)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "timm (>=0.9.0)", "torch (==2.7.1)", "torch-fidelity (<=0.4.0)", "torch_linear_assignment (>=0.0.2)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>=4.43.0)", "transformers (>=4.43.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate", "vmaf-torch (>=1.1.0)"] +audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"] +clustering = ["torch_linear_assignment (>=0.0.2)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"] +dev = ["PyTDC (==0.4.1) ; python_version < \"3.10\" or platform_system == \"Windows\" and python_version < \"3.12\"", "SciencePlots (>=2.0.0)", "aeon (>=1.0.0) ; python_version > \"3.10\"", "bert_score (==0.3.13)", "dists-pytorch (==0.1)", "dython (==0.7.9)", "einops (>=0.7.0)", "einops (>=0.7.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.35)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0) ; python_version < \"3.12\"", "mecab-ko-dic (>=1.0.0) ; python_version < \"3.12\"", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.4.0)", "mypy (==1.17.1)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.4.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "properscoring (==0.1)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "timm (>=0.9.0)", "torch (==2.7.1)", "torch-fidelity (<=0.4.0)", "torch_complex (<0.5.0)", "torch_linear_assignment (>=0.0.2)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>=4.43.0)", "transformers (>=4.43.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate", "vmaf-torch (>=1.1.0)"] +image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"] +multimodal = ["einops (>=0.7.0)", "piq (<=0.8.0)", "timm (>=0.9.0)", "transformers (>=4.43.0)"] +text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>=4.43.0)"] +typing = ["mypy (==1.17.1)", "torch (==2.7.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +video = ["einops (>=0.7.0)", "vmaf-torch (>=1.1.0)"] +visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"] + [[package]] name = "tqdm" version = "4.67.1" @@ -2785,7 +3425,126 @@ einops = ["einops"] numba = ["numba (>=0.55)"] test = ["hypothesis", "packaging", "preliz (>=0.19)", "pytest", "pytest-cov", "scipy 
(>=1.15)"] +[[package]] +name = "yarl" +version = "1.20.1" +description = "Yet another URL library" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "yarl-1.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6032e6da6abd41e4acda34d75a816012717000fa6839f37124a47fcefc49bec4"}, + {file = "yarl-1.20.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2c7b34d804b8cf9b214f05015c4fee2ebe7ed05cf581e7192c06555c71f4446a"}, + {file = "yarl-1.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c869f2651cc77465f6cd01d938d91a11d9ea5d798738c1dc077f3de0b5e5fed"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62915e6688eb4d180d93840cda4110995ad50c459bf931b8b3775b37c264af1e"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:41ebd28167bc6af8abb97fec1a399f412eec5fd61a3ccbe2305a18b84fb4ca73"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:21242b4288a6d56f04ea193adde174b7e347ac46ce6bc84989ff7c1b1ecea84e"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bea21cdae6c7eb02ba02a475f37463abfe0a01f5d7200121b03e605d6a0439f8"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f8a891e4a22a89f5dde7862994485e19db246b70bb288d3ce73a34422e55b23"}, + {file = "yarl-1.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd803820d44c8853a109a34e3660e5a61beae12970da479cf44aa2954019bf70"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b982fa7f74c80d5c0c7b5b38f908971e513380a10fecea528091405f519b9ebb"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:33f29ecfe0330c570d997bcf1afd304377f2e48f61447f37e846a6058a4d33b2"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:835ab2cfc74d5eb4a6a528c57f05688099da41cf4957cf08cad38647e4a83b30"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:46b5e0ccf1943a9a6e766b2c2b8c732c55b34e28be57d8daa2b3c1d1d4009309"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:df47c55f7d74127d1b11251fe6397d84afdde0d53b90bedb46a23c0e534f9d24"}, + {file = "yarl-1.20.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76d12524d05841276b0e22573f28d5fbcb67589836772ae9244d90dd7d66aa13"}, + {file = "yarl-1.20.1-cp310-cp310-win32.whl", hash = "sha256:6c4fbf6b02d70e512d7ade4b1f998f237137f1417ab07ec06358ea04f69134f8"}, + {file = "yarl-1.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:aef6c4d69554d44b7f9d923245f8ad9a707d971e6209d51279196d8e8fe1ae16"}, + {file = "yarl-1.20.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:47ee6188fea634bdfaeb2cc420f5b3b17332e6225ce88149a17c413c77ff269e"}, + {file = "yarl-1.20.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0f6500f69e8402d513e5eedb77a4e1818691e8f45e6b687147963514d84b44b"}, + {file = "yarl-1.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7a8900a42fcdaad568de58887c7b2f602962356908eedb7628eaf6021a6e435b"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bad6d131fda8ef508b36be3ece16d0902e80b88ea7200f030a0f6c11d9e508d4"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = 
"sha256:df018d92fe22aaebb679a7f89fe0c0f368ec497e3dda6cb81a567610f04501f1"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f969afbb0a9b63c18d0feecf0db09d164b7a44a053e78a7d05f5df163e43833"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:812303eb4aa98e302886ccda58d6b099e3576b1b9276161469c25803a8db277d"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98c4a7d166635147924aa0bf9bfe8d8abad6fffa6102de9c99ea04a1376f91e8"}, + {file = "yarl-1.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12e768f966538e81e6e7550f9086a6236b16e26cd964cf4df35349970f3551cf"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fe41919b9d899661c5c28a8b4b0acf704510b88f27f0934ac7a7bebdd8938d5e"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8601bc010d1d7780592f3fc1bdc6c72e2b6466ea34569778422943e1a1f3c389"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:daadbdc1f2a9033a2399c42646fbd46da7992e868a5fe9513860122d7fe7a73f"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:03aa1e041727cb438ca762628109ef1333498b122e4c76dd858d186a37cec845"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:642980ef5e0fa1de5fa96d905c7e00cb2c47cb468bfcac5a18c58e27dbf8d8d1"}, + {file = "yarl-1.20.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:86971e2795584fe8c002356d3b97ef6c61862720eeff03db2a7c86b678d85b3e"}, + {file = "yarl-1.20.1-cp311-cp311-win32.whl", hash = "sha256:597f40615b8d25812f14562699e287f0dcc035d25eb74da72cae043bb884d773"}, + {file = "yarl-1.20.1-cp311-cp311-win_amd64.whl", hash = "sha256:26ef53a9e726e61e9cd1cda6b478f17e350fb5800b4bd1cd9fe81c4d91cfeb2e"}, + {file = "yarl-1.20.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdcc4cd244e58593a4379fe60fdee5ac0331f8eb70320a24d591a3be197b94a9"}, + {file = "yarl-1.20.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b29a2c385a5f5b9c7d9347e5812b6f7ab267193c62d282a540b4fc528c8a9d2a"}, + {file = "yarl-1.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1112ae8154186dfe2de4732197f59c05a83dc814849a5ced892b708033f40dc2"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90bbd29c4fe234233f7fa2b9b121fb63c321830e5d05b45153a2ca68f7d310ee"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:680e19c7ce3710ac4cd964e90dad99bf9b5029372ba0c7cbfcd55e54d90ea819"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a979218c1fdb4246a05efc2cc23859d47c89af463a90b99b7c56094daf25a16"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255b468adf57b4a7b65d8aad5b5138dce6a0752c139965711bdcb81bc370e1b6"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a97d67108e79cfe22e2b430d80d7571ae57d19f17cda8bb967057ca8a7bf5bfd"}, + {file = "yarl-1.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8570d998db4ddbfb9a590b185a0a33dbf8aafb831d07a5257b4ec9948df9cb0a"}, + {file = "yarl-1.20.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97c75596019baae7c71ccf1d8cc4738bc08134060d0adfcbe5642f778d1dca38"}, + {file = 
"yarl-1.20.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:1c48912653e63aef91ff988c5432832692ac5a1d8f0fb8a33091520b5bbe19ef"}, + {file = "yarl-1.20.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4c3ae28f3ae1563c50f3d37f064ddb1511ecc1d5584e88c6b7c63cf7702a6d5f"}, + {file = "yarl-1.20.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c5e9642f27036283550f5f57dc6156c51084b458570b9d0d96100c8bebb186a8"}, + {file = "yarl-1.20.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2c26b0c49220d5799f7b22c6838409ee9bc58ee5c95361a4d7831f03cc225b5a"}, + {file = "yarl-1.20.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564ab3d517e3d01c408c67f2e5247aad4019dcf1969982aba3974b4093279004"}, + {file = "yarl-1.20.1-cp312-cp312-win32.whl", hash = "sha256:daea0d313868da1cf2fac6b2d3a25c6e3a9e879483244be38c8e6a41f1d876a5"}, + {file = "yarl-1.20.1-cp312-cp312-win_amd64.whl", hash = "sha256:48ea7d7f9be0487339828a4de0360d7ce0efc06524a48e1810f945c45b813698"}, + {file = "yarl-1.20.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0b5ff0fbb7c9f1b1b5ab53330acbfc5247893069e7716840c8e7d5bb7355038a"}, + {file = "yarl-1.20.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:14f326acd845c2b2e2eb38fb1346c94f7f3b01a4f5c788f8144f9b630bfff9a3"}, + {file = "yarl-1.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f60e4ad5db23f0b96e49c018596707c3ae89f5d0bd97f0ad3684bcbad899f1e7"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49bdd1b8e00ce57e68ba51916e4bb04461746e794e7c4d4bbc42ba2f18297691"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:66252d780b45189975abfed839616e8fd2dbacbdc262105ad7742c6ae58f3e31"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59174e7332f5d153d8f7452a102b103e2e74035ad085f404df2e40e663a22b28"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e3968ec7d92a0c0f9ac34d5ecfd03869ec0cab0697c91a45db3fbbd95fe1b653"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a4fbb50e14396ba3d375f68bfe02215d8e7bc3ec49da8341fe3157f59d2ff5"}, + {file = "yarl-1.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11a62c839c3a8eac2410e951301309426f368388ff2f33799052787035793b02"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:041eaa14f73ff5a8986b4388ac6bb43a77f2ea09bf1913df7a35d4646db69e53"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:377fae2fef158e8fd9d60b4c8751387b8d1fb121d3d0b8e9b0be07d1b41e83dc"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1c92f4390e407513f619d49319023664643d3339bd5e5a56a3bebe01bc67ec04"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d25ddcf954df1754ab0f86bb696af765c5bfaba39b74095f27eececa049ef9a4"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:909313577e9619dcff8c31a0ea2aa0a2a828341d92673015456b3ae492e7317b"}, + {file = "yarl-1.20.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:793fd0580cb9664548c6b83c63b43c477212c0260891ddf86809e1c06c8b08f1"}, + {file = "yarl-1.20.1-cp313-cp313-win32.whl", hash = "sha256:468f6e40285de5a5b3c44981ca3a319a4b208ccc07d526b20b12aeedcfa654b7"}, + {file = "yarl-1.20.1-cp313-cp313-win_amd64.whl", hash = 
"sha256:495b4ef2fea40596bfc0affe3837411d6aa3371abcf31aac0ccc4bdd64d4ef5c"}, + {file = "yarl-1.20.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:f60233b98423aab21d249a30eb27c389c14929f47be8430efa7dbd91493a729d"}, + {file = "yarl-1.20.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6f3eff4cc3f03d650d8755c6eefc844edde99d641d0dcf4da3ab27141a5f8ddf"}, + {file = "yarl-1.20.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:69ff8439d8ba832d6bed88af2c2b3445977eba9a4588b787b32945871c2444e3"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cf34efa60eb81dd2645a2e13e00bb98b76c35ab5061a3989c7a70f78c85006d"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:8e0fe9364ad0fddab2688ce72cb7a8e61ea42eff3c7caeeb83874a5d479c896c"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f64fbf81878ba914562c672024089e3401974a39767747691c65080a67b18c1"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6342d643bf9a1de97e512e45e4b9560a043347e779a173250824f8b254bd5ce"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56dac5f452ed25eef0f6e3c6a066c6ab68971d96a9fb441791cad0efba6140d3"}, + {file = "yarl-1.20.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7d7f497126d65e2cad8dc5f97d34c27b19199b6414a40cb36b52f41b79014be"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:67e708dfb8e78d8a19169818eeb5c7a80717562de9051bf2413aca8e3696bf16"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:595c07bc79af2494365cc96ddeb772f76272364ef7c80fb892ef9d0649586513"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7bdd2f80f4a7df852ab9ab49484a4dee8030023aa536df41f2d922fd57bf023f"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:c03bfebc4ae8d862f853a9757199677ab74ec25424d0ebd68a0027e9c639a390"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:344d1103e9c1523f32a5ed704d576172d2cabed3122ea90b1d4e11fe17c66458"}, + {file = "yarl-1.20.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:88cab98aa4e13e1ade8c141daeedd300a4603b7132819c484841bb7af3edce9e"}, + {file = "yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d"}, + {file = "yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f"}, + {file = "yarl-1.20.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e42ba79e2efb6845ebab49c7bf20306c4edf74a0b20fc6b2ccdd1a219d12fad3"}, + {file = "yarl-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:41493b9b7c312ac448b7f0a42a089dffe1d6e6e981a2d76205801a023ed26a2b"}, + {file = "yarl-1.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f5a5928ff5eb13408c62a968ac90d43f8322fd56d87008b8f9dabf3c0f6ee983"}, + {file = "yarl-1.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30c41ad5d717b3961b2dd785593b67d386b73feca30522048d37298fee981805"}, + {file = "yarl-1.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:59febc3969b0781682b469d4aca1a5cab7505a4f7b85acf6db01fa500fa3f6ba"}, + {file = 
"yarl-1.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d2b6fb3622b7e5bf7a6e5b679a69326b4279e805ed1699d749739a61d242449e"}, + {file = "yarl-1.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:749d73611db8d26a6281086f859ea7ec08f9c4c56cec864e52028c8b328db723"}, + {file = "yarl-1.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9427925776096e664c39e131447aa20ec738bdd77c049c48ea5200db2237e000"}, + {file = "yarl-1.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff70f32aa316393eaf8222d518ce9118148eddb8a53073c2403863b41033eed5"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c7ddf7a09f38667aea38801da8b8d6bfe81df767d9dfc8c88eb45827b195cd1c"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:57edc88517d7fc62b174fcfb2e939fbc486a68315d648d7e74d07fac42cec240"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:dab096ce479d5894d62c26ff4f699ec9072269d514b4edd630a393223f45a0ee"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:14a85f3bd2d7bb255be7183e5d7d6e70add151a98edf56a770d6140f5d5f4010"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:2c89b5c792685dd9cd3fa9761c1b9f46fc240c2a3265483acc1565769996a3f8"}, + {file = "yarl-1.20.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:69e9b141de5511021942a6866990aea6d111c9042235de90e08f94cf972ca03d"}, + {file = "yarl-1.20.1-cp39-cp39-win32.whl", hash = "sha256:b5f307337819cdfdbb40193cad84978a029f847b0a357fbe49f712063cfc4f06"}, + {file = "yarl-1.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:eae7bfe2069f9c1c5b05fc7fe5d612e5bbc089a39309904ee8b829e322dcad00"}, + {file = "yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77"}, + {file = "yarl-1.20.1.tar.gz", hash = "sha256:d017a4997ee50c91fd5466cef416231bb82177b93b029906cefc542ce14c35ac"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" +propcache = ">=0.2.1" + [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "b81fb19fcfdd825f41e7c53cbe04222bbdcc9e697676b14b19623fbe18b2730a" +content-hash = "cd2b1d6e9d5466d7701f1714986095d9bf84d2f6654d91676ffcfecb9754f1d7" diff --git a/pyproject.toml b/pyproject.toml index dfd2544..9ab9d8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ matplotlib = "^3.8.3" tarp = "^0.1.1" deprecation = "^2.1.0" scipy = "^1.15.0" +einops = "^0.8.1" +lightning = "^2.5.4" [tool.poetry.group.dev.dependencies] pre-commit = "^3.3.2" diff --git a/resources/saveddata/Hierarchicalpendulum_data_test_40.h5 b/resources/saveddata/Hierarchicalpendulum_data_test_40.h5 new file mode 100644 index 0000000..d303e63 Binary files /dev/null and b/resources/saveddata/Hierarchicalpendulum_data_test_40.h5 differ diff --git a/resources/savedmodels/hnpe_src/__init__.py b/resources/savedmodels/hnpe_src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resources/savedmodels/hnpe_src/diagnostics.py b/resources/savedmodels/hnpe_src/diagnostics.py new file mode 100644 index 0000000..dca7d41 --- /dev/null +++ b/resources/savedmodels/hnpe_src/diagnostics.py @@ -0,0 +1,406 @@ + +import os + +import numpy as np +import torch +import matplotlib.pyplot as plt +import pymc as pm +from pytorch_lightning import LightningDataModule, LightningModule +from einops import rearrange, repeat + +from scipy.stats import 
binom, gaussian_kde, iqr + +from scipy.stats import binom +import torch +from tqdm import tqdm +from pytorch_lightning import LightningDataModule, LightningModule + +import logging +module_logger = logging.getLogger('diagnostics') +logger = logging.getLogger('diagnostics.Diagnostics') +import arviz as az + +class Diagnostics(): + def __init__( + self, + trained_model:LightningModule, + dm: LightningDataModule, + local_samples=None, # shape (nbatch, nset, params) + global_samples=None, # shape (nbatch, nparams) + local_labels=None, + global_labels=None, + local_colors=None, + global_colors=None, + seed = 5, + outdir=None, + overwrite_saved_samples=False, + n_posterior_samples=None, + n_eval=None, + ): + self.trained_model = trained_model + self.dm = dm + outdir = outdir if outdir is not None else './outdir/' + if not os.path.exists(outdir): + os.makedirs(outdir) + self.outdir = outdir + self.overwrite_saved_samples = overwrite_saved_samples + if torch.cuda.is_available(): + self.device = torch.device('cuda') + logger.info('Using cuda') + else: + self.device = torch.device('cpu') + logger.info('Using cpu') + self.n_posterior_samples = 1000 if n_posterior_samples is None else n_posterior_samples + logger.info(f'Set number of posterior samples: {self.n_posterior_samples}') + self.n_eval = n_eval + # generate posterior samples if not input + if global_samples is not None and local_samples is not None: + logger.info('Using samples passed in') + self.local_samples = local_samples + self.global_samples = global_samples + else: + logger.info('Getting posterior samples') + self.local_samples, self.global_samples = self.trained_model.get_local_and_global_posterior_samples( + self.dm, + device1=self.device, + n_eval=self.n_eval, + save_dir=self.outdir, + n_samples=self.n_posterior_samples, + overwrite_if_exists=self.overwrite_saved_samples + ) + torch.save(self.local_samples, self.outdir + 'local_samples.pt') + torch.save(self.global_samples, self.dm.outdir + 'global_samples.pt') + + self.n_global_params = self.global_samples.shape[-1] + self.n_local_params = self.local_samples.shape[-1] + self.n_total_params = self.n_global_params + self.n_local_params + # make sure number of params in model equals number of params in samples + assert self.trained_model.hparams['n_global_params'] == self.n_global_params + assert self.trained_model.hparams['n_local_params'] == self.n_local_params + + # generate automatic labels and colors if not input + self.global_labels = global_labels if global_labels is not None else [f'global_{i}' for i in range(self.n_global_params)] + self.local_labels = local_labels if local_labels is not None else [f'local_{i}' for i in range(self.n_local_params)] + self.global_colors = global_colors if global_colors is not None else [f'C{i}' for i in range(self.n_global_params)] + self.local_colors = local_colors if local_colors is not None else [f'C{i}' for i in range(self.n_global_params, self.n_total_params)] + + self.x_test, self.y_local_test, self.y_global_test = dm.data_test.tensors + + # in case we need specific seed + self.seed = seed + + ## sort true global params and samples in increasing order and get means and stds + glob_srt = torch.argsort(self.y_global_test, dim=0) + self.y_global_test_sorted = torch.take_along_dim(self.y_global_test, glob_srt, dim=0) + self.global_samples_mean_sorted = torch.take_along_dim(self.global_samples.mean(dim=-2), glob_srt, dim=0) + self.global_samples_std_sorted = torch.take_along_dim(self.global_samples.std(dim=-2), glob_srt, dim=0) + + ## sort 
true local params and samples in increasing order and get means and stds + loc_srt = torch.argsort(self.y_local_test.flatten(0,1), dim=0) + self.y_local_test_sorted = torch.take_along_dim(self.y_local_test.flatten(0,1), loc_srt, dim=0) + self.local_samples_mean_sorted = torch.take_along_dim(self.local_samples.flatten(0,1).mean(dim=-2), loc_srt, dim=0) + self.local_samples_std_sorted = torch.take_along_dim(self.local_samples.flatten(0,1).std(dim=-2), loc_srt, dim=0) + + def plot_model_predictions(self, **kwargs): + + # set up plotting defaults if None input + kwargs['label_fontsize'] = kwargs['label_fontsize'] if 'label_fontsize' in kwargs else None + kwargs['legend_fontsize'] = kwargs['legend_fontsize'] if 'legend_fontsize' in kwargs else None + kwargs['tick_fontsize'] = kwargs['tick_fontsize'] if 'tick_fontsize' in kwargs else None + kwargs['marker'] = kwargs['marker'] if 'marker' in kwargs else 'o' + kwargs['edgecolor'] = kwargs['edgecolor'] if 'edgecolor' in kwargs else 'black' + kwargs['linewidth'] = kwargs['linewidth'] if 'linewidth' in kwargs else .5 + kwargs['capsize'] = kwargs['capsize'] if 'capsize' in kwargs else 0 + kwargs['alpha'] = kwargs['alpha'] if 'alpha' in kwargs else 1 + kwargs['s'] = kwargs['s'] if 's' in kwargs else 50 + kwargs['elinewidth'] = kwargs['elinewidth'] if 'elinewidth' in kwargs else .5 + + # global prediction plot + fig_glob, axs_glob = self._one_level_model_predictions( + self.y_global_test_sorted, + self.global_samples_mean_sorted, + self.global_samples_std_sorted, + self.n_global_params, + self.global_labels, + **kwargs + ) + + # local prediction plot + fig_loc, axs_loc = self._one_level_model_predictions( + self.y_local_test_sorted[::25], + self.local_samples_mean_sorted[::25], + self.local_samples_std_sorted[::25], + self.n_local_params, + self.local_labels, + **kwargs + ) + fig_glob.savefig(self.outdir + 'model_predictions_global.png') + fig_loc.savefig(self.outdir + 'model_predictions_local.png') + return (fig_glob, axs_glob), (fig_loc, axs_loc); + + def _one_level_model_predictions(self, y_true, y_samples_mean, y_samples_std, n_params, labels, **kwargs): + # plot + fig, axs = plt.subplots(2, n_params, sharex='col', figsize=(8*n_params, 16*n_params)) + + # plot one column per param + for i_param in range(n_params): + y_true = y_true[:, i_param].detach().numpy() + y_samples_mean = y_samples_mean[:, i_param].detach().numpy() + y_samples_std = y_samples_std[:, i_param].detach().numpy() + + # plot column + self._plot_one_col_model_pred( + y_true=y_true, + y_samples_mean=y_samples_mean, + y_samples_std=y_samples_std, + axs=axs[:, i_param] if n_params > 1 else axs, + fig=fig, + label=labels[i_param], + **kwargs + ) + fig.tight_layout() + return fig, axs + + def _plot_one_col_model_pred(self, + y_true, + y_samples_mean, + y_samples_std, + axs, + fig, + label, + **kwargs + ): + + self._plot_pairplot_model( + y_true, + y_samples_mean, + y_samples_std, + axs, + fig, + label, + **kwargs + ) + self._plot_residuals_model_pred( + y_true, + y_samples_mean, + y_samples_std, + axs, + label, + **kwargs + ) + + @staticmethod + def _plot_pairplot_model( + y_true, + y_samples_mean, + y_samples_std, + axs, + fig, + label, + **kwargs + ): + + ax = axs[0] + ax.plot(y_true, y_true, label='true', color='red', linestyle='--', zorder=1000) + ax.scatter( + y_true, + y_samples_mean, + marker=kwargs['marker'], + edgecolor=kwargs['edgecolor'], + linewidth=kwargs['linewidth'], + s=kwargs['s'], + ) + ax.errorbar( + y_true, + y_samples_mean, + yerr=y_samples_std, + fmt='none', + 
color='black', + elinewidth=kwargs['elinewidth'], + capsize=kwargs['capsize'], + alpha=kwargs['alpha'], + label=r'1$\sigma$', + ) + ax.legend(fontsize=kwargs['legend_fontsize']) + ax.set_ylabel(label, fontsize=kwargs['label_fontsize']) + ax.tick_params(labelsize=kwargs['tick_fontsize']) + + @staticmethod + def _plot_residuals_model_pred( + y_true, + y_samples_mean, + y_samples_std, + axs, + label, + **kwargs + ): + + ax = axs[1] + ax.scatter( + y_true, + y_samples_mean - y_true, + linewidth=kwargs['linewidth'], + marker=kwargs['marker'], + edgecolor=kwargs['edgecolor'], + s=kwargs['s'], + ) + ax.errorbar( + y_true, + y_samples_mean - y_true, + yerr=y_samples_std, + fmt='none', + color='black', + elinewidth=kwargs['elinewidth'], + capsize=kwargs['capsize'], + label=r'1$\sigma$', + ) + ax.hlines(0, xmin=y_true.min(), xmax=y_true.max(), color='red', linestyle='--', linewidth=2) + ax.legend(fontsize=kwargs['legend_fontsize']) + ax.set_xlabel(label, fontsize=kwargs['label_fontsize']) + ax.set_ylabel(label, fontsize=kwargs['label_fontsize']) + ax.tick_params(labelsize=kwargs['tick_fontsize']) + + def run_sbc(self, line_alpha=.8, uniform_region_alpha=.3, figsize=(4,4), fig_global=None, ax_global=None, fig_local=None, ax_local=None): + # global + glob_ranks = self._one_level_ranks(self.x_test, self.n_global_params, self.y_global_test, self.global_samples) + fig_glob, axs_glob = self._sbc_ecdf_rank_plot(glob_ranks, self.global_labels, self.global_colors, line_alpha, uniform_region_alpha, figsize=figsize, fig=fig_global, ax=ax_global) + + # local + loc_ranks = self._one_level_ranks(self.x_test, self.n_local_params, self.y_local_test, self.local_samples) + fig_loc, axs_loc = self._sbc_ecdf_rank_plot(loc_ranks, self.local_labels, self.local_colors, line_alpha, uniform_region_alpha, figsize=figsize, fig=fig_local, ax=ax_local) + + fig_glob.savefig(self.outdir + 'ECDF_global.png') + fig_loc.savefig(self.outdir + 'ECDF_local.png') + + @staticmethod + def _one_level_ranks(x, n_params, y_true, posterior_samples): + reduce_1d_fn = [eval(f"lambda theta, x: theta[:, {i}]") for i in range(n_params)] + n_set = x.shape[1] + x = x.flatten(0, 1) # (nbatch nset) ndata + if posterior_samples.shape[1] == n_set: #ie, if local var w/ shape (nbatch nset nsamp nparam) + y_true = y_true.flatten(0,1) + posterior_samples = posterior_samples.flatten(0,1) # flatten to get (nbatch nset) nsamp nparams + else: # else, is global with shape (nbatch nsamp nparam) + y_true = repeat(y_true, 'nbatch nparam -> (nbatch nset) nparam', nset=n_set) + posterior_samples = repeat(posterior_samples, 'nbatch nsamp nparam -> (nbatch nset) nsamp nparam', nset=n_set) + + # n_params = y_true.shape[-1] + n_sbc_runs = posterior_samples.shape[0] + + ranks = torch.zeros((n_sbc_runs, len(reduce_1d_fn))) + + # calculate ranks + for sbc_idx, (y_true0, x0) in tqdm( + enumerate(zip(y_true, x, strict=False)), + total=n_sbc_runs, + desc=f"Calculating ranks for {n_sbc_runs} sbc samples.", + ): + for dim_idx, reduce_fn in enumerate(reduce_1d_fn): + # rank posterior samples against true parameter, reduced to 1D. 
+ ranks[sbc_idx, dim_idx] = ( + (reduce_fn(posterior_samples[sbc_idx, :, :], x0) < reduce_fn(y_true0.unsqueeze(0), + x0)).sum().item() + ) + return ranks + + def _sbc_ecdf_rank_plot(self, ranks, param_labels, colors, line_alpha=None, uniform_region_alpha=None, figsize=None, fig=None, ax=None): + # plot ranks + ranks_list = [ranks] + n_sbc_runs, num_parameters = ranks_list[0].shape + + n_bins = n_sbc_runs // 20 + num_repeats = 50 + + if fig is None or ax is None: + fig, ax = plt.subplots(1, 1, figsize=figsize) + plt.sca(ax) + + ranki = ranks_list[0] + for jj in range(num_parameters): + self._plot_ranks_as_cdf( + ranki[:, jj], # type: ignore + n_bins, + num_repeats, + ranks_label=param_labels[jj], + color=colors[jj], + xlabel="posterior rank", + # Plot ylabel and legend at last. + show_ylabel = jj == (num_parameters - 1), + alpha=line_alpha, + ) + self._plot_cdf_region_expected_under_uniformity( + n_sbc_runs, + n_bins, + num_repeats, + alpha=uniform_region_alpha, + ) + # show legend on the last subplot. + # plt.legend(**legend_kwargs) + plt.legend() + return fig, ax # pyright: ignore[reportReturnType] + + @staticmethod + def _plot_ranks_as_cdf( + ranks, + n_bins, + n_repeats, + ranks_label, + xlabel, + color, + alpha=.8, + show_ylabel=True, + num_ticks=3 + ): + hist, *_ = np.histogram(ranks, bins=n_bins, density=False) + # Construct empirical CDF. + histcs = hist.cumsum() + # Plot cdf and repeat each stair step + plt.plot( + np.linspace(0, n_bins, n_repeats * n_bins), + np.repeat(histcs / histcs.max(), n_repeats), + label=ranks_label, + color=color, + alpha=alpha, + ) + + if show_ylabel: + plt.yticks(np.linspace(0, 1, 3)) + plt.ylabel("empirical CDF") + else: + # Plot ticks only + plt.yticks(np.linspace(0, 1, 3), []) + + plt.ylim(0, 1) + plt.xlim(0, n_bins) + plt.xticks(np.linspace(0, n_bins, num_ticks)) + plt.xlabel("posterior rank" if xlabel is None else xlabel) + + @staticmethod + def _plot_cdf_region_expected_under_uniformity( + n_sbc_samples, + n_bins, + n_repeats, + alpha=.2, + color='grey' + ): + + # Construct uniform histogram. + uni_bins = binom(n_sbc_samples, p=1 / n_bins).ppf(0.5) * np.ones(n_bins) + uni_bins_cdf = uni_bins.cumsum() / uni_bins.sum() + # Decrease value one in last entry by epsilon to find valid + # confidence intervals. + uni_bins_cdf[-1] -= 1e-9 + + lower = [binom(n_sbc_samples, p=p).ppf(0.005) for p in uni_bins_cdf] + upper = [binom(n_sbc_samples, p=p).ppf(0.995) for p in uni_bins_cdf] + + # Plot grey area with expected ECDF. 
+ plt.fill_between( + x=np.linspace(0, n_bins, n_repeats * n_bins), + y1=np.repeat(lower / np.max(lower), n_repeats), + y2=np.repeat(upper / np.max(upper), n_repeats), # pyright: ignore[reportArgumentType] + color=color, + alpha=alpha, + label="expected under uniformity", + ) + diff --git a/resources/savedmodels/hnpe_src/flows.py b/resources/savedmodels/hnpe_src/flows.py new file mode 100644 index 0000000..4e0ea15 --- /dev/null +++ b/resources/savedmodels/hnpe_src/flows.py @@ -0,0 +1,117 @@ +from functools import partial + +import torch +from torch import nn, tanh, relu + +from nflows import distributions as distributions_ +from nflows import flows, transforms +from nflows.nn import nets + + +def build_maf(dim=1, + num_transforms=8, + context_features=None, + hidden_features=128, + num_blocks=2, + use_residual_blocks=False, + random_mask=False, + activation_function=torch.relu, + dropout_probability=0.0, + use_batch_norm=False): + transform = transforms.CompositeTransform( + [ + transforms.CompositeTransform( + [ + transforms.MaskedAffineAutoregressiveTransform( + features=dim, + hidden_features=hidden_features, + context_features=context_features, + num_blocks=num_blocks, + use_residual_blocks=use_residual_blocks, + random_mask=random_mask, + activation=activation_function, + dropout_probability=dropout_probability, + use_batch_norm=use_batch_norm, + ), + transforms.RandomPermutation(features=dim), + ] + ) + for _ in range(num_transforms) + ] + ) + + distribution = distributions_.StandardNormal((dim,)) + neural_net = flows.Flow(transform, distribution) + + return neural_net + + +def create_alternating_binary_mask(features, even=True): + """ + Creates a binary mask of a given dimension which alternates its masking. + :param features: Dimension of mask. + :param even: If True, even values are assigned 1s, odd 0s. If False, vice versa. + :return: Alternating binary mask of type torch.Tensor. + """ + mask = torch.zeros(features).byte() + start = 0 if even else 1 + mask[start::2] += 1 + return mask + + +def mask_in_layer(i, features): + return create_alternating_binary_mask(features=features, even=(i % 2 == 0)) + + +def build_nsf(dim=1, num_transforms=8, context_features=None, hidden_features=128, tail_bound=3.0, num_bins=10): + conditioner = partial( + nets.ResidualNet, + hidden_features=hidden_features, + context_features=context_features, + num_blocks=2, + activation=relu, + dropout_probability=0.0, + use_batch_norm=False, + ) + + # Stack spline transforms. + transform_list = [] + for i in range(num_transforms): + block = [ + transforms.PiecewiseRationalQuadraticCouplingTransform( + mask=mask_in_layer(i, dim), + transform_net_create_fn=conditioner, + num_bins=num_bins, + tails="linear", + tail_bound=tail_bound, + apply_unconditional_transform=False, + ) + ] + + # Add LU transform only for high D x. Permutation makes sense only for more than + # one feature. + block.append( + transforms.LULinear(dim, identity_init=True), + ) + transform_list += block + + distribution = distributions_.StandardNormal((dim,)) + + # Combine transforms. 
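+    # The stacked coupling-spline and LU blocks are composed into a single transform and
+    # paired with a standard-normal base distribution to form the flow.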
+ transform = transforms.CompositeTransform(transform_list) + neural_net = flows.Flow(transform, distribution) + + return neural_net + + +def build_mlp(input_dim, hidden_dim, output_dim, layers): + """Create a MLP from the configurations""" + + activation = nn.GELU + + seq = [nn.Linear(input_dim, hidden_dim), activation()] + for _ in range(layers): + seq += [nn.Linear(hidden_dim, hidden_dim), activation()] + seq += [nn.Linear(hidden_dim, output_dim)] + + return nn.Sequential(*seq) diff --git a/resources/savedmodels/hnpe_src/neural_nets.py b/resources/savedmodels/hnpe_src/neural_nets.py new file mode 100644 index 0000000..d3f47e1 --- /dev/null +++ b/resources/savedmodels/hnpe_src/neural_nets.py @@ -0,0 +1,278 @@ +import torch +import torch.nn as nn + +from einops import rearrange, repeat +import numpy as np +from tqdm import tqdm + +import os +# import sys +# import gc +import logging +module_logger = logging.getLogger('diagnostics') +logger = logging.getLogger('diagnostics.neural_net') + +from resnet import ResNetEstimator +from flows import build_mlp, build_maf + +import pytorch_lightning as pl +# from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping + +torch.set_default_dtype(torch.float32) + +class HierarchicalDeepSet(nn.Module): + """ Backbone to the hierarchical deep set model, using a ResNet embedder and MAF flows for + local and global parameter posterior density estimators. + """ + def __init__(self, n_local_params, n_global_params, n_transforms, n_set_max, dim_hidden=128, obs_shape=None, condition_local_on_global=True): + super(HierarchicalDeepSet, self).__init__() + + inference_net_kwargs = {"cfg":50} + self.n_set_max = n_set_max + + obs_dim = len(obs_shape) + if obs_dim == 1: + print('Observation data is 1d: using mlp encoder') + self.enc = build_mlp(input_dim=obs_shape[0], hidden_dim=int(2 * dim_hidden), output_dim=dim_hidden, layers=4) + else: + print('Observation data is 2d: using ResNet encoder') + self.enc = ResNetEstimator(n_out=dim_hidden, **inference_net_kwargs) + self.dec = build_mlp(input_dim=int(dim_hidden / 2) + 1, hidden_dim=int(2 * dim_hidden), output_dim=int(dim_hidden / 2), layers=4).float() + + # Condition local flow on global params if local loss is turned on + extra_context = np.max([n_local_params - n_global_params, 1, n_global_params]) + self.condition_local_on_global = condition_local_on_global + + self.flow_local = build_maf(dim=n_local_params, num_transforms=n_transforms, context_features=int(dim_hidden / 2) + extra_context, hidden_features=int(2 * dim_hidden)).float() + self.flow_global = build_maf(dim=n_global_params, num_transforms=n_transforms, context_features=int(dim_hidden / 2), hidden_features=int(2 * dim_hidden)).float() + + def forward(self, x, y_local, y_global): + batch_size = x.shape[0] + + lens = torch.randint(low=1, high=self.n_set_max + 1,size=(batch_size,), dtype=torch.float32) + mask = (torch.arange(self.n_set_max).expand(len(lens), self.n_set_max) < torch.Tensor(lens)[:,None]).to(x.device) + + x = rearrange(x, "batch n_set l -> (batch n_set) l", n_set=self.n_set_max) + x = self.enc(x) + + x = rearrange(x, "(batch n_set) n_out -> batch n_set n_out", n_set=self.n_set_max) + + idx_setperm = torch.randperm(self.n_set_max) # Permutation indices + x = x[:, idx_setperm, :] * mask[:, :, None] # Permute set elements and mask + y_local = y_local[:, idx_setperm, :] + + x, x_cond_local = torch.chunk(x, 2, -1) + + x = x.sum(-2) / mask.sum(1)[:, None] + + x = torch.cat([x, lens[:, None].to(x.device)], -1) # Add 
cardinality for rho network + x_cond_global = self.dec(x) + + x_cond_local = rearrange(x_cond_local, "batch n_set n_out -> (batch n_set) n_out", n_set=self.n_set_max) + + if self.condition_local_on_global: + y_global_repeat = repeat(y_global, "batch glob -> (batch n_set) glob", n_set=self.n_set_max) + x_cond_local = torch.cat([x_cond_local, y_global_repeat], -1) + + y_local = rearrange(y_local, "batch n_set n_param -> (batch n_set) n_param", n_set=self.n_set_max) + + log_prob_local = self.flow_local.log_prob(y_local, x_cond_local) + log_prob_local = rearrange(log_prob_local, "(batch n_set) -> batch n_set", n_set=self.n_set_max) + log_prob_local = (log_prob_local * mask).sum(-1) + + log_prob_global = self.flow_global.log_prob(y_global, x_cond_global) + + return log_prob_local, log_prob_global + + +class HierarchicalDeepSetInference(pl.LightningModule): + """ Hierarchical deep set lightning module for training and inference. + """ + + def __init__(self, + optimizer=torch.optim.AdamW, + optimizer_kwargs=None, + lr=3e-4, + scheduler=torch.optim.lr_scheduler.CosineAnnealingLR, + scheduler_kwargs=None, + local_loss=True, + global_loss=True, + n_local_params=1, + n_global_params=1, + n_transforms=6, + n_set_max=25, + dim_hidden=128, + obs_shape=None, + **kwargs, + ): + super().__init__() + + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs + self.scheduler = scheduler + self.scheduler_kwargs = scheduler_kwargs + self.lr = lr + self.obs_shape = obs_shape + self.local_loss = local_loss + self.global_loss = global_loss + self.n_set_max = n_set_max + self.n_local_params = n_local_params + self.n_global_params = n_global_params + condition_local_on_global = True if (local_loss and global_loss) else False + + self.deep_set = HierarchicalDeepSet( + n_local_params=n_local_params, + n_global_params=n_global_params, + n_transforms=n_transforms, + n_set_max=n_set_max, + dim_hidden=dim_hidden, + obs_shape=obs_shape, + condition_local_on_global=condition_local_on_global + ) + self.save_hyperparameters() + + def test_step(self, batch, batch_idx): + return {'test_local_loss': log_prob_local, 'test_global_loss':log_prob_global, 'samples_y_local': samples_local, 'samples_y_global': samples_global, 'y_local_true': y_local, 'y_global_true': y_global} + + def forward(self, x, y_local, y_global): + log_prob = self.deep_set(x, y_local, y_global) + return log_prob + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters(), lr=self.lr, **self.optimizer_kwargs) + return {"optimizer": optimizer, + "lr_scheduler": { + "scheduler": self.scheduler(optimizer, **self.scheduler_kwargs), + "interval": "epoch", + "monitor": "val_loss", + "frequency": 1} + } + + def training_step(self, batch, batch_idx): + x, y_local, y_global = batch + log_prob_local, log_prob_global = self(x, y_local, y_global) + log_prob = torch.zeros_like(log_prob_local).to(log_prob_local.device) + if self.local_loss: + log_prob += log_prob_local + if self.global_loss: + log_prob += log_prob_global + loss = -log_prob.mean() + self.log('train_loss', loss, on_epoch=True) + self.log('local_train_loss', -log_prob_local.mean(), on_epoch=True) + self.log('global_train_loss', -log_prob_global.mean(), on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y_local, y_global = batch + log_prob_local, log_prob_global = self(x, y_local, y_global) + log_prob = torch.zeros_like(log_prob_local).to(log_prob_local.device) + if self.local_loss: + log_prob += log_prob_local + if self.global_loss: + log_prob 
+= log_prob_global + loss = -log_prob.mean() + self.log('val_loss', loss, on_epoch=True) + self.log('val_local_loss', -log_prob_local.mean(), on_epoch=True) + self.log('val_global_loss', -log_prob_global.mean(), on_epoch=True) + return loss + + def get_deepset(self): + return self.deep_set + + + def get_local_and_global_posterior_samples(self, dm, device1='cpu', n_eval=None, n_samples=1000, save_dir=None, overwrite_if_exists=False): + global_samples = self.get_global_posterior_samples( + dm, + device1=device1, + n_eval=n_eval, + n_samples=n_samples, + save_dir=save_dir, + overwrite_if_exists=overwrite_if_exists, + ) + local_samples = self.get_local_posterior_samples( + dm, + device1=device1, + n_eval=n_eval, + n_samples=n_samples, + save_dir=save_dir, + overwrite_if_exists=overwrite_if_exists, + ) + return local_samples, global_samples + + def get_global_posterior_samples(self, dm, device1='cpu', n_eval=None, n_samples=1000, batch_size=50, save_dir=None, overwrite_if_exists=False): + filename='global_samples.pt' + if save_dir is None: + save_dir = dm.data_dir + + if os.path.exists(save_dir + filename) and not overwrite_if_exists: + logger.info(f'Loading global samples from {save_dir + filename}') + return torch.load(save_dir + filename) + + if not os.path.exists(save_dir + filename): + logger.info(f'Will save global samples to {save_dir + filename}') + + # x, y_local, y_global = dm.data_test.tensors + if n_eval is None: + n_eval = self.n_set_max + + x, y_local, y_global = dm.data_test.tensors + device2 = 'cpu' + x = rearrange(x[:, :n_eval], "batch n_set l -> (batch n_set) l", n_set=n_eval) + x = self.deep_set.enc.to(device1)(x.to(device1)) + x = rearrange(x, "(batch n_set) n_out -> batch n_set n_out", n_set=n_eval) + + x_cond_global, x_cond_local = torch.chunk(x, 2, -1) + x_cond_global = x_cond_global.mean(-2) + col_n_eval = torch.full(size=(len(x_cond_global),), fill_value=n_eval)[:, None] + x_cond_global = torch.cat([x_cond_global.to(device2), col_n_eval], dim=-1) + x_cond_global = self.deep_set.dec.to(device2)(x_cond_global.to(device2)) + + global_samples = self.deep_set.flow_global.sample(num_samples=n_samples, context=x_cond_global) + torch.save(global_samples, save_dir + filename) + return global_samples + + def get_local_posterior_samples(self, dm, device1='cpu', n_eval=None, n_samples=1000, batch_size=50, save_dir=None, overwrite_if_exists=True): + if save_dir is None: + save_dir = dm.data_dir + filename='local_samples.pt' + + if os.path.exists(save_dir + filename) and not overwrite_if_exists: + logger.info(f'Loading local samples from {save_dir + filename}') + return torch.load(save_dir + filename) + + if not os.path.exists(save_dir + filename): + logger.info(f'Will save local samples to {save_dir + filename}') + + x, y_local, y_global = dm.data_test.tensors + if n_eval is None: + n_eval = self.n_set_max + + device2 = 'cpu' + x, y_local, y_global = dm.data_test.tensors + x = rearrange(x[:, :n_eval], "batch n_set l -> (batch n_set) l", n_set=n_eval) + x = self.deep_set.enc.to(device1)(x.to(device1)) + x = rearrange(x, "(batch n_set) n_out -> batch n_set n_out", n_set=n_eval) + + x_cond_global, x_cond_local = torch.chunk(x, 2, -1) + global_samples = self.get_global_posterior_samples(dm, device1, n_eval=n_eval, n_samples=n_samples, save_dir=save_dir) + + n_batches = n_samples // batch_size + x_cond_local_global = torch.cat([x_cond_local.to(device2), self._get_global_samples_mean(global_samples, n_eval)], -1) + # x_cond_local_global = rearrange(x_cond_local_global, "batch nset nout -> 
(batch nset) nout", nset=n_eval) + + local_samples = torch.empty(size=(len(x_cond_local_global), n_eval, n_samples, self.n_local_params)) + for i_batch, x_batch in tqdm(enumerate(x_cond_local_global)): + torch.save(self.deep_set.flow_local.sample(num_samples=n_samples, context=x_cond_local_global[i_batch]), + save_dir + filename[:-3] + f'_{i_batch}.pt') + for i_batch, x_batch in tqdm(enumerate(x_cond_local_global)): + file_path = save_dir + filename[:-3] + f'_{i_batch}.pt' + local_samples[i_batch] = torch.load(file_path) + os.remove(file_path) + torch.save(local_samples, save_dir + 'local_samples.pt') + return local_samples + + @staticmethod + def _get_global_samples_mean(global_samples, n_eval): + global_samples_mean = repeat(global_samples.mean(dim=1),"batch nparam -> batch n_eval nparam", n_eval=n_eval) + return global_samples_mean + diff --git a/resources/savedmodels/hnpe_src/plotting.py b/resources/savedmodels/hnpe_src/plotting.py new file mode 100644 index 0000000..a7f66af --- /dev/null +++ b/resources/savedmodels/hnpe_src/plotting.py @@ -0,0 +1,30 @@ + + + +''' +Want: +- random samples of sbi data plotted +- histogram of each measurement distribution +- +''' + +def plot_random_image(sim_dict): + fontsize = 18 + if len(sim_dict['x'].shape) == 2: + fig, axs = plt.subplots(2, 2, figsize=(8, 8), sharex='col', sharey='row') + for i in range(2): + for j in range(2): + batch_idx, set_idx = np.random.randint(0, len(y_global)), np.random.randint(0, len(y_local[0])) + axs[i, j].imshow(x[batch_idx, set_idx]) + for i_param in range(y_global.shape[-1]): + glob_val = y_global[batch_idx, i_param] * y_global_std[i_param] + y_global_mean[i_param] + line = f"{config_dict['global_labels'][i_param]}: {glob_val.item():.2f}" + axs[i, j].annotate(line, xy=(2, 2 * (i_param + 2)), fontsize=fontsize, color='white') + + for i_param in range(y_local.shape[-1]): + loc_val = y_local[batch_idx, i_param] * y_local_std[i_param] + y_local_mean[i_param] + line = f"{config_dict['local_labels'][i_param]}: {loc_val[i_param].item():.2f}" + axs[i, j].annotate(line, xy=(2, 21 + 3 * (i_param + 2)), fontsize=fontsize, color='white') + elif len(sim_dict['x']) + fig.tight_layout() + fig.savefig(fig_outdir + 'sample_images.png') diff --git a/resources/savedmodels/hnpe_src/resnet.py b/resources/savedmodels/hnpe_src/resnet.py new file mode 100644 index 0000000..ae55333 --- /dev/null +++ b/resources/savedmodels/hnpe_src/resnet.py @@ -0,0 +1,230 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import torch +from torch.autograd import grad +import logging + +#from einops import rearrange + +logger = logging.getLogger(__name__) + + +import torch.nn as nn + +import sys + +sys.path.append("../") + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + 
self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, planes) + self.bn1 = norm_layer(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn2 = norm_layer(planes) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNetEstimator(nn.Module): + def __init__(self, n_aux=0, cfg=18, n_hidden=512, n_out=128, input_mean=None, input_std=None, log_input=False, zero_init_residual=False, norm_layer=None, zero_bias=False): + super(ResNetEstimator, self).__init__() + + self.input_mean = input_mean + self.input_std = input_std + self.log_input = log_input + + block, layers = self._load_cfg(cfg) + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.inplanes = 64 + self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc1 = nn.Linear(n_hidden * block.expansion + n_aux, 2048) + self.fc2 = nn.Linear(2048, n_out) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if zero_bias: + try: + nn.init.constant_(m.bias, 0) + except AttributeError: + pass + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + if zero_bias: + try: + nn.init.constant_(m.bias, 0) + except AttributeError: + pass + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def forward(self, x, x_aux=None): + # Preprocessing + h = self._preprocess(x) + + # h_t = rearrange(h, 'b c h w -> b c w h') + # h = torch.cat([h, h_t], dim=1) + + # ResNet + h = self.conv1(h) + h = self.bn1(h) + h = self.relu(h) + h = self.maxpool(h) + + h = self.layer1(h) + h = self.layer2(h) + h = self.layer3(h) + h = self.layer4(h) + + h = self.avgpool(h) + h = h.view(h.size(0), -1) + if x_aux is not None: + h = torch.cat((h, x_aux), 1) + h = self.fc1(h) + h = self.relu(h) + h = self.fc2(h) + + return h + + @staticmethod + def _load_cfg(cfg): + if cfg == 18: + block = BasicBlock + layers = [2, 2, 2, 2] + elif cfg == 34: + block = BasicBlock + layers = [3, 4, 6, 3] + elif cfg == 50: + block = Bottleneck + layers = [3, 4, 6, 3] + elif cfg == 101: + block = Bottleneck + layers = [3, 4, 23, 3] + elif cfg == 152: + block = Bottleneck + layers = [3, 8, 36, 3] + else: + raise ValueError("Unknown ResNet configuration {}, use 18, 34, 50, 101, or 152!".format(cfg)) + + return block, layers + + def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): + if norm_layer is None: + norm_layer = nn.BatchNorm2d + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion)) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _preprocess(self, x): + if self.log_input: + x = torch.log(1.0 + x) + if self.input_mean is not None and self.input_std is not None: + x = x - self.input_mean + x = x / self.input_std + x = x.unsqueeze(1) + return x diff --git a/resources/savedmodels/hnpe_src/toy_data.py b/resources/savedmodels/hnpe_src/toy_data.py new file mode 100644 index 0000000..68220c0 --- /dev/null +++ b/resources/savedmodels/hnpe_src/toy_data.py @@ -0,0 +1,275 @@ +import numpy as np +import sys +import os +from abc import ABC, abstractmethod, abstractstaticmethod + +import pytorch_lightning as pl +import torch +from tensorflow_probability.substrates.jax.internal.test_util import disable_test_for_backend +from torch.utils.data import DataLoader, random_split, TensorDataset +from torch.distributions import Uniform, Normal, HalfNormal, Exponential +# from torch.distributions.uniform import Uniform +# from torch.distributions.normal import Normal +# from torch.distributions.half_normal import HalfNormal +# from torch.distributions.exponential import Exponential +import matplotlib.pyplot as plt + +import logging +module_logger = logging.getLogger('diagnostics') +logger = logging.getLogger('diagnostics.toy_data') + +@np.vectorize +def normalize(data, mean, std): + return (data - mean)/std + +class BaseDataModule(pl.LightningDataModule, ABC): + def __init__(self, + loader_batch_size=None, + exp_dir = None, + relative_data_dir = None, + train_fn = None, + test_fn = None, + transform = None, + generator_seed_1 = 32, + generator_seed_2 = 25, + n_batch = 20_000, + n_set_max = 25, + n_features = 5, + device = 'cpu', + local_labels = None, + global_labels = None, + 
test_data_seed = 150, + SIMULATOR_INDEX = None, + use_hyperpriors=False, + prior_type='uniform', + ): + super().__init__() + self.loader_batch_size = loader_batch_size + self.global_labels = global_labels + self.local_labels = local_labels + self.exp_dir = exp_dir + self.data_dir = exp_dir + relative_data_dir + self.train_data_path = exp_dir + train_fn + self.test_data_path = exp_dir + test_fn + self.transform = transform + self.n_batch = n_batch + self.n_set_max = n_set_max + self.n_features = n_features + self.generator_seed_1 = generator_seed_1 + self.generator_seed_2 = generator_seed_2 + self.pin_memory = False if device == 'cpu' else True + self.test_data_seed = test_data_seed + self.SIMULATOR_INDEX = SIMULATOR_INDEX + self.use_hyperpriors = use_hyperpriors + self.prior_type = prior_type + + @staticmethod + def get_times(n_features): + t = torch.linspace(1, 2, n_features) + return t + + @abstractmethod + def simulator(self, **kwargs): + raise NotImplementedError + + @abstractmethod + def simulate_new_dataset(self, n_batch=None, n_sets=None, n_features=5, priors='uniform', **kwargs): + raise NotImplementedError + + def prepare_data(self, n_test_batch=100, overwrite_saved_test_data=False): + # Prepare train data + # if train data already exists, load. Otherwise, simulate and save + if os.path.exists(self.train_data_path): + logger.info('Will use train data saved in ' + self.train_data_path) + else: + logger.info(f'Simulating new train data') + torch.save( + self.simulate_new_dataset(n_batch=self.n_batch, n_set=self.n_set_max, n_features=self.n_features), + self.train_data_path + ) + # Prepare test data + if os.path.exists(self.test_data_path): + test_data = torch.load(self.test_data_path, weights_only=False) + num_test_data_points_unequal = len(test_data['x']) != n_test_batch + if num_test_data_points_unequal: + logger.info( + f'Number of desired test data points is not equal to number in saved test_data file' + ) + if num_test_data_points_unequal or overwrite_saved_test_data: + if num_test_data_points_unequal: + logger.info('The saved test data set does not have the desired number of test batches. /n' + 'Re-simulating test dataset' + ) + if overwrite_saved_test_data: + logger.info(f'Overwriting saved test data set with seed {self.test_data_seed}') + torch.manual_seed(self.test_data_seed) + logger.info(f'Simulating new test dataset and saving in {self.test_data_path}') + torch.save(self.simulate_new_dataset( + n_batch=n_test_batch, + n_set=self.n_set_max, + n_features=self.n_features), + self.test_data_path + ) + else: + logger.info(f'No test data found in {self.test_data_path}. Simulating new test dataset') + torch.save( + self.simulate_new_dataset(n_batch=n_test_batch, n_set=self.n_set_max, n_features=self.n_features), + self.test_data_path + ) + + def setup(self, stage='test'): + '''Assign train, val, and test data''' + generator1 = torch.Generator().manual_seed(self.generator_seed_1) + if stage == "fit": + logger.info(f'dm.setup in "fit" stage. Loading train data from {self.train_data_path}') + data_full_dict = torch.load(self.train_data_path, weights_only=False) + x_full = data_full_dict['x'] + y_local_full = data_full_dict['y_local'] + y_global_full = data_full_dict['y_global'] + dataset_full = TensorDataset(x_full, y_local_full, y_global_full) + self.data_train, self.data_val = random_split(dataset_full, [.9, .1], generator=generator1) + if stage == "test": + logger.info(f'dm.setup in "test" stage. 
Loading test data from {self.test_data_path}') + data_test_dict = torch.load(self.test_data_path, weights_only=False) + dataset_test = TensorDataset(data_test_dict['x'], data_test_dict['y_local'], data_test_dict['y_global']) + logger.info('Setting dm.data_test to loaded test data') + self.data_test = dataset_test + + def train_dataloader(self): + return DataLoader(self.data_train, + batch_size=self.loader_batch_size, + num_workers=4, + pin_memory=self.pin_memory, + shuffle=True + ) + + def val_dataloader(self): + return DataLoader(self.data_val, + batch_size=self.loader_batch_size, + num_workers=4, + pin_memory=self.pin_memory, + shuffle=False + ) + def test_dataloader(self): + return DataLoader(self.data_test, + batch_size=self.loader_batch_size, + num_workers=8, + pin_memory=self.pin_memory, + shuffle=False + ) + + def plot_random_train_samples(self, n_points=10, save_dir=None): + batch_idxs = np.random.randint(low=0, high=self.n_batch, size=(n_points)) + set_idxs = np.random.randint(low=0, high=self.n_set_max, size=(n_points)) + sim_dict = torch.load(self.train_data_path, weights_only=False) + fig, ax = plt.subplots(1) + for i, (b_idx, set_idx) in enumerate(zip(batch_idxs, set_idxs)): + ax.scatter(sim_dict['t'], sim_dict['x'][b_idx, set_idx], color=f'C{i}') + ax.plot(sim_dict['t'], sim_dict['x_true'][b_idx, set_idx], color=f'C{i}') + ax.set_xlabel('t') + ax.set_ylabel('x') + if save_dir is not None: + fig.savefig(self.exp_dir + 'random_train_samples_plot.png') + +class VaryingSlopeInterceptDM(BaseDataModule): + def __init__( + self, + loader_batch_size, + exp_dir = './', + relative_data_dir='', + train_fn='VSI_train_data.pt', + test_fn='VSI_test_data.pt', + transform=None, + generator_seed_1 = 32, + generator_seed_2 = 25, + n_batch = 20_000, + n_set_max = 25, + n_features = 5, + device='cpu', + test_data_seed=150, + use_hyperpriors=False, + prior_type='uniform', + ): + global_labels = ['b (intercept)'] + local_labels = ['m (slope)'] + + super().__init__(loader_batch_size, + exp_dir, + relative_data_dir, + train_fn, + test_fn, + transform, + generator_seed_1, + generator_seed_2, + n_batch, + n_set_max, + n_features, + device, + local_labels, + global_labels, + test_data_seed, + 0, + use_hyperpriors, + prior_type=prior_type, + ) + + # if not os.path.exists(self.train_data_path): + # raise FileNotFoundError(f'{self.train_data_path} does not exist') + # if not os.path.exists(self.test_data_path): + # raise FileNotFoundError(f'{self.test_data_path} does not exist') + + def simulator(self, t, m, b): + if b.shape != m.shape: + b = b[:, np.newaxis, :] + return m * t + b + + def simulate_new_dataset(self, n_batch=None, n_set=None, n_features=5, prior_type=None): + data_dict = {} + if prior_type is None: + prior_type = self.prior_type + if self.use_hyperpriors: + if prior_type == 'normal': + b_hyperpriors = [Normal(0, 2), HalfNormal(3)] # mu, sigma + b_hyperparams = [hyperprior.sample((n_batch,)) for hyperprior in b_hyperpriors] + b_prior = Normal(loc=b_hyperparams[0], scale=b_hyperparams[1]) + + m_hyperpriors = [Normal(0, 2), HalfNormal(3)] # mu, sigma + m_hyperparams = [hyperprior.sample((n_batch, n_set,)) for hyperprior in m_hyperpriors] + m_prior = Normal(loc=m_hyperparams[0], scale=m_hyperparams[1]) + elif prior_type == 'uniform': + b_hyperpriors = [Uniform(-3, -1), Uniform(1,3)] # mu, sigma + b_hyperparams = [hyperprior.sample((n_batch,)) for hyperprior in b_hyperpriors] + b_prior = Normal(loc=b_hyperparams[0], scale=b_hyperparams[1]) + + m_hyperpriors = [Normal(0, 2), HalfNormal(3)] # mu, 
sigma + m_hyperparams = [hyperprior.sample((n_batch, n_set,)) for hyperprior in m_hyperpriors] + m_prior = Normal(loc=m_hyperparams[0], scale=m_hyperparams[1]) + + m = m_prior.sample((1,)).reshape((n_batch, n_set, 1)) + b = b_prior.sample((1,)).reshape((n_batch, 1)) + + data_dict.update({'global_hyperparams': b_hyperparams, + 'local_hyperparams': m_hyperparams, + 'global_prior': b_prior, + 'local_prior': m_prior, + 'global_hyperpriors': b_hyperpriors, + 'local_hyperpriors': m_hyperpriors}) + else: + if prior_type == 'uniform': + m_prior = Uniform(low=-2, high=2) + b_prior = Uniform(low=-2, high=2) + elif prior_type == 'normal': + m_prior = Normal(loc=0, scale=4) + b_prior = Normal(loc=0, scale=4) + m = m_prior.sample((n_batch, n_set, 1)) + b = b_prior.sample((n_batch, 1)) + sigma_x_true = torch.randn(size=(n_batch, n_set, n_features)) + sigma_x_true /= 10 + t = self.get_times(n_features) + x_true = self.simulator(t, m, b) + x = x_true + sigma_x_true + data_dict.update({'x': x, 'x_true': x_true, 'y_local': m, 'y_global': b, + 'sigma_x_true': sigma_x_true, 't': t} + ) + return data_dict + diff --git a/resources/savedmodels/trained_model_0909.pkl b/resources/savedmodels/trained_model_0909.pkl new file mode 100644 index 0000000..07e70fc Binary files /dev/null and b/resources/savedmodels/trained_model_0909.pkl differ diff --git a/src/deepdiagnostics/data/__init__.py b/src/deepdiagnostics/data/__init__.py index 43a04a9..a584410 100644 --- a/src/deepdiagnostics/data/__init__.py +++ b/src/deepdiagnostics/data/__init__.py @@ -1,4 +1,4 @@ -from deepdiagnostics.data.h5_data import H5Data +from deepdiagnostics.data.h5_data import H5Data, H5HierarchyData from deepdiagnostics.data.pickle_data import PickleData -DataModules = {"H5Data": H5Data, "PickleData": PickleData} +DataModules = {"H5Data": H5Data, "PickleData": PickleData, "H5HierarchyData": H5HierarchyData} diff --git a/src/deepdiagnostics/data/data.py b/src/deepdiagnostics/data/data.py index ad9722b..c1adc3a 100644 --- a/src/deepdiagnostics/data/data.py +++ b/src/deepdiagnostics/data/data.py @@ -5,6 +5,10 @@ from deepdiagnostics.data.lookup_table_simulator import LookupTableSimulator from deepdiagnostics.utils.simulator_utils import load_simulator +# in hierarchy data, there are two n_dims, n_local and n_global which are summed to give n_dims. +# they can be accessed as data["n_local"] and data["n_global"] +# in this data class, n_dims is in __init__. would that be an issue? + class Data: """ Load stored data to use in diagnostics @@ -50,7 +54,8 @@ def __init__( self.thetas = self._thetas() self.prior_dist = self.load_prior(prior, prior_kwargs) - self.n_dims = self.thetas.shape[1] + # Uncomment this for NPE + # self.n_dims = self.thetas.shape[1] self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) self.simulator_outcome = self._simulator_outcome() diff --git a/src/deepdiagnostics/data/h5_data.py b/src/deepdiagnostics/data/h5_data.py index 81d834f..2785852 100644 --- a/src/deepdiagnostics/data/h5_data.py +++ b/src/deepdiagnostics/data/h5_data.py @@ -116,3 +116,125 @@ def get_sigma_true(self): return super().get_sigma_true() except (AssertionError, KeyError): return 1 + + +class H5HierarchyData(Data): + """ + Load data that has been saved in a h5 format. + + If you cast your problem to be a pendulum simulation with y = L * sin(sqrt(g/L) * x), these are the fields required and what they represent: + + simulator_outcome e.g. 
positions - y
+    thetas - parameters of the model - L, g
+    context e.g. time - xs
+
+    metadata fields:
+    num_global - number of global samples
+    num_local - number of local samples
+    n_global - number of global parameters
+    n_local - number of local parameters
+
+    .. attribute:: Data Parameters
+
+        :ys: [REQUIRED] The outcomes of the simulator. The posterior of the thetas is evaluated at these outcomes.
+        :thetas: [REQUIRED] The theta, the true parameters of the external model that generated the outcomes.
+        :xs: [REQUIRED] The context, the known inputs to the simulator.
+        :num_global: [REQUIRED] Number of global samples e.g. num_global = 200 means 200 different settings of the global parameters.
+        :num_local: [REQUIRED] Number of local samples e.g. num_local = 50 means 50 different settings of the local parameters for each global parameter setting.
+        :n_global: [REQUIRED] Number of global parameters e.g. n_global = 1 means 1 global parameter (like g in the pendulum example).
+        :n_local: [REQUIRED] Number of local parameters e.g. n_local = 1 means 1 local parameter (like L in the pendulum example).
+        :simulator_dimensions: [REQUIRED] Dimensionality of a single simulator outcome, e.g. 1 for a time series and 2 for an image.
+        :simulator: [OPTIONAL] A simulator function that takes in thetas and context and outputs outcomes.
+
+    To add:
+        save functionality
+        prior functionality
+        sigma functionality
+
+    """
+
+    def __init__(self,
+                 path,
+                 simulator,
+                 simulator_kwargs=None,
+                 prior=None,
+                 prior_kwargs=None,
+                 simulation_dimensions=None,
+                 ):
+        # self.path = path
+        # assert os.path.exists(path), f"Missing file: {path}"
+        # self.data = self._load(path)
+        super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)
+
+    def _load(self, path):
+        assert path.split(".")[-1] == "h5", "File extension must be h5"
+        loaded_data = {}
+        with h5py.File(path, "r") as file:
+            for key, value in file.attrs.items():
+                loaded_data[key] = value
+            for key in file.keys():
+                loaded_data[key] = torch.Tensor(file[key][...])
+        return loaded_data
+
+    def _simulator_outcome(self):
+        """
+        Get stored simulator outcomes at which the thetas are inferred.
+        Returns:
+            simulator outcomes array in the format y = (num_global, num_local, simulator_dimensions)
+        """
+        try:
+            sim_outcome = self.data["simulator_outcome"]
+            num_global = self.data["num_global"]
+            num_local = self.data["num_local"]
+            sim_outcome = sim_outcome.reshape(num_global, num_local, -1)
+            return sim_outcome
+        except KeyError:
+            try:
+                # No stored outcomes: run the simulator at each stored theta instead.
+                sim_outcome = np.zeros((self.simulator_dimensions, len(self.thetas)))
+                for index, theta in enumerate(self.thetas):
+                    sim_out = self.simulator(theta=theta.unsqueeze(0), n_samples=1)
+                    sim_outcome[:, index] = sim_out
+                num_global = self.data["num_global"]
+                num_local = self.data["num_local"]
+                sim_outcome = sim_outcome.reshape(num_global, num_local, -1)
+                return sim_outcome
+
+            except Exception as e:
+                raise ValueError(f"Data does not have a `simulator_outcome` field and could not generate it from a simulator: {e}")
+
+    def _context(self):
+        """ Get stored context used to train the model.
+        Returns:
+            context array
+        """
+        try:
+            # The pendulum example uses a single fixed time grid, so one context row is enough;
+            # this needs to be generalized for per-observation contexts.
+            return self.data["context"][0]
+        except KeyError:
+            raise NotImplementedError("Data does not have a `context` field.")
+
+    def _thetas(self):
+        """ Get stored theta used to train the model.
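+        Assumes the stored theta table keeps the local parameter in its first column and
+        the global parameter in its second column before reshaping.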
+ + Returns: + theta array in the format y_local = (num_global, num_local, n_local) and y_global = (num_global, n_global) + + Raise: + NotImplementedError: Data does not have thetas. + """ + try: + thetas = self.data["thetas"] + num_global = self.data["num_global"] + num_local = self.data["num_local"] + n_local = self.data["n_local"] + n_global = self.data["n_global"] + y_local = thetas[:,0].reshape(num_global, num_local, n_local) + y_global = thetas[:,1].view(num_global, num_local, n_global) + y_global = y_global[:, 0, :] + return y_local, y_global + + except KeyError: + raise NotImplementedError("Data does not have a `thetas` field.") + diff --git a/src/deepdiagnostics/models/__init__.py b/src/deepdiagnostics/models/__init__.py index 29ee5f3..e8c3026 100644 --- a/src/deepdiagnostics/models/__init__.py +++ b/src/deepdiagnostics/models/__init__.py @@ -1,3 +1,4 @@ -from deepdiagnostics.models.sbi_model import SBIModel +from deepdiagnostics.models.sbi_model import SBIModel, HierarchyModel + +ModelModules = {"SBIModel": SBIModel, "HierarchyModel": HierarchyModel} -ModelModules = {"SBIModel": SBIModel} diff --git a/src/deepdiagnostics/models/model.py b/src/deepdiagnostics/models/model.py index 46e2c81..adde939 100644 --- a/src/deepdiagnostics/models/model.py +++ b/src/deepdiagnostics/models/model.py @@ -10,7 +10,7 @@ def __init__(self, model_path: str) -> None: def _load(self, path: str) -> None: return NotImplementedError - + def sample_posterior(self): return NotImplementedError diff --git a/src/deepdiagnostics/models/sbi_model.py b/src/deepdiagnostics/models/sbi_model.py index dedba51..f4a8f7f 100644 --- a/src/deepdiagnostics/models/sbi_model.py +++ b/src/deepdiagnostics/models/sbi_model.py @@ -1,8 +1,16 @@ import os import pickle +import sys + +# imports for Hierarchical model +import torch +from einops import rearrange, repeat from deepdiagnostics.models.model import Model +# import sys +# sys.path.append('/Users/jarugula/Research/Deepdiagnostics/DeepDiagnostics/hnpe_src') + class SBIModel(Model): """ @@ -54,3 +62,133 @@ def predict_posterior(self, data, context_samples): posterior_samples, context_samples ) return posterior_predictive_samples + + +class HierarchyModel(Model): + """ + Load the trained hierarchical model for deep sets saved as pickle file + pickle.dump(trained_model, file) + + The hierarchy deep sets model also needs to load its neural network components from the specified path. + + Args: + model_nn_path (str): Path to the neural network model file. + model_path (str): Path to the trained model pickle file. + """ + def __init__(self, model_path, model_nn_path): + # Load the model + self.model_nn_path = model_nn_path + super().__init__(model_path) + + def _load(self, path: str): + assert os.path.exists(path), f"Cannot find model file at location {path}" + assert os.path.exists(self.model_nn_path), f"Cannot find hierarchy deep sets model at location {self.model_nn_path}" + assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" + + if self.model_nn_path not in sys.path: + sys.path.insert(0, self.model_nn_path) + + with open(path, "rb") as file: + model = pickle.load(file) + + # self.model = model + print("Model loaded successfully.") + return model + + def _sample_global(self, x, n_samples=1000, device="cpu"): + """ + Sample from the global thetas using the deep set model eval. 
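+        The per-element embeddings are mean-pooled over the set, the set size is appended,
+        and the decoded summary is used as the conditioning context of the global flow,
+        mirroring HierarchicalDeepSet.forward.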
+ + Args: + x (torch.Tensor): The output of the simulator at which the posteriors are evaluated as (number_of_global_samples, number_of_local_samples, simulator_dimensions) + n_samples (int): Number of samples to draw + device (str): Device to run the evaluation on + + Returns: + torch.Tensor: Samples from the global thetas + """ + deep_set = self.model.deep_set + deep_set.eval() + n_eval = x.shape[-2] + device = torch.device(device) + deep_set.to(device) + + x_flat = rearrange(x[:, :n_eval], "batch n_set l -> (batch n_set) l", n_set=n_eval) + x_enc = deep_set.enc.to(device)(x_flat.to(device)) + x_enc = rearrange(x_enc, "(batch n_set) n_out -> batch n_set n_out", n_set=n_eval) + + x_cond_global, _ = torch.chunk(x_enc, 2, -1) + x_cond_global = x_cond_global.mean(-2) + col_n_eval = torch.full((x_cond_global.size(0), 1), float(n_eval), dtype=torch.float32, device=device) + x_cond_global = torch.cat([x_cond_global, col_n_eval], dim=-1) + x_cond_global = deep_set.dec.to(device)(x_cond_global) + + samples = deep_set.flow_global.sample(num_samples=n_samples, context=x_cond_global) + # normalize shape to (batch, n_samples, dim) + if samples.dim() == 3 and samples.size(0) == n_samples: + samples = samples.permute(1, 0, 2).contiguous() + return samples + + def _sample_local(self, x, n_samples=1000, device="cpu"): + """ + Sample from the local context distribution using the deep set model eval. + + Args: + x (torch.Tensor): The output of the simulator at which the posteriors are evaluated as (number_of_global_samples, number_of_local_samples, simulator_dimensions) + n_samples (int): Number of samples to draw + device (str): Device to run the evaluation on + + Returns: + torch.Tensor: Samples from the local thetas + """ + deep_set = self.model.deep_set + deep_set.eval() + n_eval = x.shape[-2] + device = torch.device(device) + deep_set.to(device) + + # Encode observations + x_flat = rearrange(x[:, :n_eval], "batch n_set l -> (batch n_set) l", n_set=n_eval) + x_enc = deep_set.enc.to(device)(x_flat.to(device)) + x_enc = rearrange(x_enc, "(batch n_set) n_out -> batch n_set n_out", n_set=n_eval) + _, x_cond_local = torch.chunk(x_enc, 2, -1) # (batch, n_eval, ctx_dim_local) + + # Global mean context + g_samples = self._sample_global(x, n_samples=n_samples, device=str(device)) # pass x, not data + g_mean = g_samples.mean(dim=1) # (batch, n_global) + g_mean = repeat(g_mean, "batch p -> batch n_eval p", n_eval=n_eval) # (batch, n_eval, n_global) + + # Build local context and sample + ctx = torch.cat([x_cond_local, g_mean], dim=-1) # (batch, n_eval, ctx_dim) + batch, n_eval_eff, ctx_dim = ctx.shape + ctx_flat = ctx.reshape(batch * n_eval_eff, ctx_dim) # (batch*n_eval, ctx_dim) + + samples = deep_set.flow_local.sample(num_samples=n_samples, context=ctx_flat) + # Expected shape (n_samples, batch*n_eval, n_local) -> (batch, n_eval, n_samples, n_local) + if samples.dim() == 3 and samples.size(1) == batch * n_eval_eff: + samples = samples.permute(1, 0, 2).contiguous() + samples = samples.reshape(batch, n_eval_eff, n_samples, -1) + return samples + + def sample_posterior(self, n_samples, x_true, global_samples): + """ + Sample either the global or local thetas from the posterior distribution. 
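+        Thin wrapper that dispatches to ``_sample_global`` or ``_sample_local`` depending on
+        the ``global_samples`` flag.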
+ + Args: + n_samples (int): Number of samples to draw + x_true (torch.Tensor): The true observations + global_samples (bool): Whether to sample from the global or local distribution + + Returns: + torch.Tensor: Samples from the global or local thetas + + """ + if global_samples == True: + print("Evaluating global samples") + global_samples = self._sample_global(x_true, n_samples=n_samples) + return global_samples + elif global_samples == False: + print("Evaluating local samples") + local_samples = self._sample_local(x_true, n_samples=n_samples) + return local_samples + \ No newline at end of file diff --git a/src/deepdiagnostics/plots/__init__.py b/src/deepdiagnostics/plots/__init__.py index e83274c..157dab0 100644 --- a/src/deepdiagnostics/plots/__init__.py +++ b/src/deepdiagnostics/plots/__init__.py @@ -1,10 +1,10 @@ -from deepdiagnostics.plots.cdf_ranks import CDFRanks +from deepdiagnostics.plots.cdf_ranks import CDFRanks, HierarchyCDFRanks from deepdiagnostics.plots.coverage_fraction import CoverageFraction -from deepdiagnostics.plots.ranks import Ranks +from deepdiagnostics.plots.ranks import Ranks, HierarchyRanks from deepdiagnostics.plots.tarp import TARP from deepdiagnostics.plots.local_two_sample import LocalTwoSampleTest as LC2ST from deepdiagnostics.plots.predictive_posterior_check import PPC -from deepdiagnostics.plots.parity import Parity +from deepdiagnostics.plots.parity import Parity, HierarchyParity from deepdiagnostics.plots.predictive_prior_check import PriorPC diff --git a/src/deepdiagnostics/plots/cdf_ranks.py b/src/deepdiagnostics/plots/cdf_ranks.py index 99d1bb2..fac35f9 100644 --- a/src/deepdiagnostics/plots/cdf_ranks.py +++ b/src/deepdiagnostics/plots/cdf_ranks.py @@ -1,5 +1,7 @@ from sbi.analysis import sbc_rank_plot, run_sbc from torch import tensor +import torch +from tqdm import tqdm from typing import TYPE_CHECKING, Union from deepdiagnostics.utils.utils import DataDisplay @@ -77,3 +79,73 @@ def plot(self, data_display: Union[DataDisplay, str], **kwargs) -> tuple["fig", parameter_labels=self.parameter_names, colors=self.parameter_colors, ) + +class HierarchyCDFRanks(CDFRanks): + """ + Plots the ranks as a CDF plot for each theta parameter. + Option to choose between global and local ranks. + + .. 
code-block:: python + + from deepdiagnostics.plots import HierarchyCDFRanks + + HierarchyCDFRanks(model, data, global_samples=True, save=False, show=True)() + + """ + def __init__(self, model, data, global_samples: bool = True, **kwargs): + self.global_samples = bool(global_samples) + super().__init__(model, data, **kwargs) + + def plot_name(self, **kwargs) -> str: + gs = kwargs.pop("global_samples", self.global_samples) + if gs: + return "global_cdf_ranks.png" + else: + return "local_cdf_ranks.png" + + def _data_setup(self, **kwargs) -> DataDisplay: + gs = kwargs.pop("global_samples", self.global_samples) + gs = bool(gs) + + # support attribute or callable access + x = self.data.simulator_outcome() if callable(self.data.simulator_outcome) else self.data.simulator_outcome + thetas = self.data.thetas() if callable(self.data.thetas) else self.data.thetas + + # sample hierarchical posterior + posterior_samples = self.model.sample_posterior(self.samples_per_inference, x, global_samples=gs) + + n_params = posterior_samples.shape[-1] + + reduce_1d_fn = [eval(f"lambda theta, x: theta[:, {i}]") for i in range(n_params)] + + # pick global or local targets + y_true = thetas[1] if gs else thetas[0] + + # flatten if local samples such that [200, 25, 1000, 1] -> [200*25, 1000, 1] + if not gs: + posterior_samples = posterior_samples.reshape(-1, posterior_samples.shape[-2], posterior_samples.shape[-1]) + y_true = y_true.reshape(-1, y_true.shape[-1]) + + n_sbc_runs = posterior_samples.shape[0] + + # flatten x + x = x.reshape(-1, x.shape[-1]) + + ranks = torch.zeros((n_sbc_runs, len(reduce_1d_fn))) + + # calculate ranks + for sbc_idx, (y_true0, x0) in tqdm( + enumerate(zip(y_true, x, strict=False)), + total=n_sbc_runs, + desc=f"Calculating ranks for {n_sbc_runs} sbc samples.", + ): + for dim_idx, reduce_fn in enumerate(reduce_1d_fn): + # rank posterior samples against true parameter, reduced to 1D. 
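+                # Same rank statistic as sbi's run_sbc, computed by hand here for the
+                # hierarchical (global/local) posterior samples.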
+ ranks[sbc_idx, dim_idx] = ( + (reduce_fn(posterior_samples[sbc_idx, :, :], x0) < reduce_fn(y_true0.unsqueeze(0), + x0)).sum().item() + ) + display_data = DataDisplay( + ranks=ranks + ) + return display_data \ No newline at end of file diff --git a/src/deepdiagnostics/plots/parity.py b/src/deepdiagnostics/plots/parity.py index 67d7a32..a267e19 100644 --- a/src/deepdiagnostics/plots/parity.py +++ b/src/deepdiagnostics/plots/parity.py @@ -3,7 +3,8 @@ from deepdiagnostics.utils.utils import DataDisplay import numpy as np - +# import for hirerarchical model +import torch from deepdiagnostics.plots.plot import Display @@ -62,6 +63,12 @@ def _data_setup(self, n_samples: int = 80, **kwargs) -> DataDisplay: true_samples[index] = self.data.thetas[sample, :] + # print shape of arrays + print(f"n_dims: {self.data.n_dims}") + print(f"true_samples shape: {true_samples.shape}") + print(f"posterior_sample_mean shape: {posterior_sample_mean.shape}") + print(f"posterior_sample_std shape: {posterior_sample_std.shape}") + return DataDisplay( n_dims=self.data.n_dims, true_samples=true_samples, @@ -119,7 +126,8 @@ def plot( figsize=(int(self.figure_size[0]*data_display.n_dims*.8), int(self.figure_size[1]*n_rows*.6)), height_ratios=height_ratios, sharex="col", - sharey=False) + sharey=False, + squeeze=False) figure.suptitle(title) figure.supxlabel(x_label) @@ -138,8 +146,9 @@ def plot( subplots[0, 0].set_ylabel("Parity") else: - parity_plot = subplots[theta_dimension] - subplots[0].set_ylabel("Parity") + print('theta_dimension', theta_dimension) + parity_plot = subplots[0, theta_dimension] + subplots[0, 0].set_ylabel("Parity") parity_plot.title.set_text(title) @@ -173,4 +182,82 @@ def plot( row_index += 1 - return figure, subplots \ No newline at end of file + return figure, subplots + +class HierarchyParity(Parity): + """ + Show plots directly comparing the posterior vs. true theta values. Make a plot that is (number of selected metrics) X dimensions of theta. + Includes the option to show differences, residual, and percent residual as plots under the main parity plot. + + Option to plot either the global or the local thetas of the Hierarchical model. + + .. 
code-block:: python
+
+        from deepdiagnostics.plots import HierarchyParity
+
+        HierarchyParity(model, data, global_samples=True,
+            show=True, save=False, run_id="090925",
+            parameter_names=["g"],
+            include_residual=True)  # Add a plot showing the residual
+
+    """
+
+    def __init__(self, model, data, global_samples: bool = True, **kwargs):
+        # Parity.__init__ doesn't know global_samples; don't pass it up
+        self.global_samples = bool(global_samples)
+        super().__init__(model, data, **kwargs)
+
+
+    def plot_name(self, **kwargs) -> str:
+        gs = kwargs.pop("global_samples", self.global_samples)
+        if gs:
+            return "global_parity.png"
+        else:
+            return "local_parity.png"
+
+    def _data_setup(self, n_samples: int = 1000, **kwargs) -> DataDisplay:
+
+        gs = kwargs.pop("global_samples", self.global_samples)
+        gs = bool(gs)
+        # support attribute or callable access
+        x_true = self.data.simulator_outcome() if callable(self.data.simulator_outcome) else self.data.simulator_outcome
+        thetas = self.data.thetas() if callable(self.data.thetas) else self.data.thetas
+
+        # sample hierarchical posterior
+        samples = self.model.sample_posterior(n_samples, x_true, global_samples=gs)
+
+        # pick global or local targets
+        y_test = thetas[1] if gs else thetas[0]
+
+        # sort by y_test to make parity plots monotonic in x
+        srt = torch.argsort(y_test, dim=0)
+        true_samples = torch.take_along_dim(y_test, srt, dim=0)
+
+        # aggregate posterior along the sample dimension (-2) and align with sorting
+        posterior_sample_mean = torch.take_along_dim(samples.mean(dim=-2), srt, dim=0)
+        posterior_sample_std = torch.take_along_dim(samples.std(dim=-2), srt, dim=0)
+
+        # if local, flatten samples to (n_samples * n_eval, n_local)
+        if not gs:
+            posterior_sample_mean = posterior_sample_mean.reshape(-1, posterior_sample_mean.shape[-1])
+            posterior_sample_std = posterior_sample_std.reshape(-1, posterior_sample_std.shape[-1])
+            true_samples = true_samples.reshape(-1, true_samples.shape[-1])
+
+        n_dims = int(y_test.shape[-1])
+
+        # convert to numpy for matplotlib
+        def to_np(t):
+            return t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t)
+
+        # print shape of arrays
+        # print(f"n_dims: {n_dims}")
+        # print(f"true_samples shape: {to_np(true_samples).shape}")
+        # print(f"posterior_sample_mean shape: {to_np(posterior_sample_mean).shape}")
+        # print(f"posterior_sample_std shape: {to_np(posterior_sample_std).shape}")
+
+        return DataDisplay(
+            n_dims=n_dims,
+            true_samples=to_np(true_samples),
+            posterior_sample_mean=to_np(posterior_sample_mean),
+            posterior_sample_std=to_np(posterior_sample_std),
+        )
+
\ No newline at end of file
diff --git a/src/deepdiagnostics/plots/ranks.py b/src/deepdiagnostics/plots/ranks.py
index 8ff86af..83b73c1 100644
--- a/src/deepdiagnostics/plots/ranks.py
+++ b/src/deepdiagnostics/plots/ranks.py
@@ -1,6 +1,8 @@
 from typing import Union, TYPE_CHECKING
 from sbi.analysis import sbc_rank_plot, run_sbc
 from torch import tensor
+import torch
+from tqdm import tqdm
 
 from deepdiagnostics.plots.plot import Display
 from deepdiagnostics.utils.utils import DataDisplay
@@ -69,3 +71,71 @@ def plot(self, data_display: Union[DataDisplay, dict], num_bins:int=20) -> tuple
             parameter_labels=self.parameter_names,
             colors=self.parameter_colors,
         )
+
+class HierarchyRanks(Ranks):
+    """
+    Plots the histogram of each theta parameter's rank.
+
+    Option to choose between global and local ranks.
+
+    ..
code-block:: python + + from deepdiagnostics.plots import HierarchyRanks + + HierarchyRanks(model, data, global_samples=True, save=False, show=True)(num_bins=25) + """ + def __init__(self, model, data, global_samples: bool = True, **kwargs): + self.global_samples = bool(global_samples) + super().__init__(model, data, **kwargs) + + def plot_name(self, **kwargs) -> str: + gs = kwargs.pop("global_samples", self.global_samples) + if gs: + return "global_ranks.png" + else: + return "local_ranks.png" + + def _data_setup(self, **kwargs) -> DataDisplay: + gs = kwargs.pop("global_samples", self.global_samples) + gs = bool(gs) + + # support attribute or callable access + x = self.data.simulator_outcome() if callable(self.data.simulator_outcome) else self.data.simulator_outcome + thetas = self.data.thetas() if callable(self.data.thetas) else self.data.thetas + + # sample hierarchical posterior + posterior_samples = self.model.sample_posterior(self.samples_per_inference, x, global_samples=gs) + n_params = posterior_samples.shape[-1] + reduce_1d_fn = [eval(f"lambda theta, x: theta[:, {i}]") for i in range(n_params)] + + # pick global or local targets + y_true = thetas[1] if gs else thetas[0] + + # flatten if local samples such that [200, 25, 1000, 1] -> [200*25, 1000, 1] + if not gs: + posterior_samples = posterior_samples.reshape(-1, posterior_samples.shape[-2], posterior_samples.shape[-1]) + y_true = y_true.reshape(-1, y_true.shape[-1]) + + n_sbc_runs = posterior_samples.shape[0] + + # flatten x + x = x.reshape(-1, x.shape[-1]) + + ranks = torch.zeros((n_sbc_runs, len(reduce_1d_fn))) + + # calculate ranks + for sbc_idx, (y_true0, x0) in tqdm( + enumerate(zip(y_true, x, strict=False)), + total=n_sbc_runs, + desc=f"Calculating ranks for {n_sbc_runs} sbc samples.", + ): + for dim_idx, reduce_fn in enumerate(reduce_1d_fn): + # rank posterior samples against true parameter, reduced to 1D. 
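+                # Ranks piled up at the extremes indicate an over-confident (too narrow)
+                # posterior, while a central bulge indicates an over-dispersed one.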
+ ranks[sbc_idx, dim_idx] = ( + (reduce_fn(posterior_samples[sbc_idx, :, :], x0) < reduce_fn(y_true0.unsqueeze(0), + x0)).sum().item() + ) + display_data = DataDisplay( + ranks=ranks + ) + return display_data diff --git a/src/deepdiagnostics/utils/defaults.py b/src/deepdiagnostics/utils/defaults.py index 2412a55..c830884 100644 --- a/src/deepdiagnostics/utils/defaults.py +++ b/src/deepdiagnostics/utils/defaults.py @@ -25,13 +25,16 @@ }, "plots": { "CDFRanks": {}, + "HierarchyCDFRanks": {}, "Ranks": {"num_bins": None}, + "HierarchyRanks": {"num_bins": None}, "CoverageFraction": {}, "TARP": { "coverage_sigma": 3 }, "LC2ST": {}, "Parity":{}, + "HierarchyParity": {}, "PPC": {}, "PriorPC":{} }, diff --git a/tests/conftest.py b/tests/conftest.py index ee4af93..c222ee5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,11 @@ import pytest import yaml import numpy as np +import sys -from deepdiagnostics.data import H5Data +from deepdiagnostics.data import H5Data, H5HierarchyData from deepdiagnostics.data.simulator import Simulator -from deepdiagnostics.models import SBIModel +from deepdiagnostics.models import SBIModel, HierarchyModel from deepdiagnostics.utils.config import get_item from deepdiagnostics.utils.simulator_utils import register_simulator @@ -75,17 +76,28 @@ def setUp(result_output): sim_paths = f"{simulator_config_path.strip('/')}/simulators.json" os.remove(sim_paths) - shutil.rmtree(result_output, ignore_errors=True) + # shutil.rmtree(result_output, ignore_errors=True) @pytest.fixture def model_path(): return "resources/savedmodels/sbi/sbi_linear_from_data.pkl" +@pytest.fixture +def hierarchy_model_path(): + return "resources/savedmodels/trained_model_0909.pkl" + +@pytest.fixture +def hierarchy_model_nn_path(): + return "resources/savedmodels/hnpe_src" @pytest.fixture def data_path(): return "resources/saveddata/data_test.h5" +@pytest.fixture +def hierarchy_data_path(): + return "resources/saveddata/Hierarchicalpendulum_data_test_40.h5" + @pytest.fixture def result_output(): path = "./temp_results/" @@ -102,11 +114,18 @@ def simulator_name(): def mock_model(model_path): return SBIModel(model_path) +@pytest.fixture +def mock_hierarchy_model(hierarchy_model_path, hierarchy_model_nn_path): + return HierarchyModel(hierarchy_model_path, hierarchy_model_nn_path) @pytest.fixture def mock_data(data_path, simulator_name): return H5Data(data_path, simulator_name) +@pytest.fixture +def mock_hierarchy_data(hierarchy_data_path, simulator_name): + return H5HierarchyData(hierarchy_data_path, simulator_name) + @pytest.fixture def mock_run_id(): return str(uuid.uuid4()).replace('-', '')[:10] @@ -116,11 +135,10 @@ def mock_2d_data(data_path): return H5Data(data_path, "Mock2DSimulator", simulation_dimensions=2) @pytest.fixture -def config_factory(result_output): +def config_factory(result_output, model_path, data_path, hierarchy_model_path, hierarchy_data_path): def factory( - model_path=None, + use_hierarchy: bool = False, model_engine=None, - data_path=None, data_engine=None, plot_2d=False, simulator=None, @@ -129,8 +147,8 @@ def factory( plots=None, metrics=None, ): - config = { - "common": {}, + cfg = { + "common": {"out_dir": result_output}, "model": {}, "data": {}, "plots_common": {}, @@ -139,44 +157,32 @@ def factory( "metrics": {}, } - # Single settings - config["common"]["out_dir"] = result_output - if model_path is not None: - config["model"]["model_path"] = model_path + mp = hierarchy_model_path if use_hierarchy else model_path + dp = hierarchy_data_path if use_hierarchy 
diff --git a/tests/conftest.py b/tests/conftest.py
index ee4af93..c222ee5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,10 +4,11 @@
 import pytest
 import yaml
 import numpy as np
+import sys
 
-from deepdiagnostics.data import H5Data
+from deepdiagnostics.data import H5Data, H5HierarchyData
 from deepdiagnostics.data.simulator import Simulator
-from deepdiagnostics.models import SBIModel
+from deepdiagnostics.models import SBIModel, HierarchyModel
 from deepdiagnostics.utils.config import get_item
 from deepdiagnostics.utils.simulator_utils import register_simulator
@@ -75,17 +76,28 @@ def setUp(result_output):
     sim_paths = f"{simulator_config_path.strip('/')}/simulators.json"
     os.remove(sim_paths)
 
-    shutil.rmtree(result_output, ignore_errors=True)
+    # shutil.rmtree(result_output, ignore_errors=True)
 
 
 @pytest.fixture
 def model_path():
     return "resources/savedmodels/sbi/sbi_linear_from_data.pkl"
 
+@pytest.fixture
+def hierarchy_model_path():
+    return "resources/savedmodels/trained_model_0909.pkl"
+
+@pytest.fixture
+def hierarchy_model_nn_path():
+    return "resources/savedmodels/hnpe_src"
 
 @pytest.fixture
 def data_path():
     return "resources/saveddata/data_test.h5"
 
+@pytest.fixture
+def hierarchy_data_path():
+    return "resources/saveddata/Hierarchicalpendulum_data_test_40.h5"
+
 @pytest.fixture
 def result_output():
     path = "./temp_results/"
@@ -102,11 +114,18 @@ def simulator_name():
 def mock_model(model_path):
     return SBIModel(model_path)
 
+@pytest.fixture
+def mock_hierarchy_model(hierarchy_model_path, hierarchy_model_nn_path):
+    return HierarchyModel(hierarchy_model_path, hierarchy_model_nn_path)
 
 @pytest.fixture
 def mock_data(data_path, simulator_name):
     return H5Data(data_path, simulator_name)
 
+@pytest.fixture
+def mock_hierarchy_data(hierarchy_data_path, simulator_name):
+    return H5HierarchyData(hierarchy_data_path, simulator_name)
+
 @pytest.fixture
 def mock_run_id():
     return str(uuid.uuid4()).replace('-', '')[:10]
@@ -116,11 +135,10 @@ def mock_2d_data(data_path):
     return H5Data(data_path, "Mock2DSimulator", simulation_dimensions=2)
 
 @pytest.fixture
-def config_factory(result_output):
+def config_factory(result_output, model_path, data_path, hierarchy_model_path, hierarchy_data_path):
     def factory(
-        model_path=None,
+        use_hierarchy: bool = False,
         model_engine=None,
-        data_path=None,
         data_engine=None,
         plot_2d=False,
         simulator=None,
@@ -129,8 +147,8 @@ def factory(
         plots=None,
         metrics=None,
     ):
-        config = {
-            "common": {},
+        cfg = {
+            "common": {"out_dir": result_output},
             "model": {},
             "data": {},
             "plots_common": {},
@@ -139,44 +157,32 @@
             },
             "metrics": {},
         }
-        # Single settings
-        config["common"]["out_dir"] = result_output
-        if model_path is not None:
-            config["model"]["model_path"] = model_path
+        mp = hierarchy_model_path if use_hierarchy else model_path
+        dp = hierarchy_data_path if use_hierarchy else data_path
+        cfg["model"]["model_path"] = str(mp)
+        cfg["data"]["data_path"] = str(dp)
+
         if model_engine is not None:
-            config["model"]["model_engine"] = model_engine
-        if data_path is not None:
-            config["data"]["data_path"] = data_path
+            cfg["model"]["model_engine"] = model_engine
         if data_engine is not None:
-            config["data"]["data_engine"] = data_engine
+            cfg["data"]["data_engine"] = data_engine
         if simulator is not None:
-            config["data"]["simulator"] = simulator
-        if plot_2d:
-            config["data"]["simulator_dimensions"] = 2
-
-        # Dict settings
-        if plot_settings is not None:
-            for key, item in plot_settings.items():
-                config["plots_common"][key] = item
-        if metrics_settings is not None:
-            for key, item in metrics_settings.items():
-                config["metrics_common"][key] = item
-
+            cfg["data"]["simulator"] = simulator
+        if plot_2d:
+            cfg["data"]["simulator_dimensions"] = 2
+
+        if plot_settings:
+            cfg["plots_common"].update(plot_settings)
+        if metrics_settings:
+            cfg["metrics_common"].update(metrics_settings)
         if metrics is not None:
-            if isinstance(metrics, dict):
-                config["metrics"] = metrics
-            if isinstance(metrics, list):
-                config["metrics"] = {metric: {} for metric in metrics}
-
+            cfg["metrics"] = metrics if isinstance(metrics, dict) else {m: {} for m in metrics}
         if plots is not None:
-            if isinstance(plots, dict):
-                config["plots"] = plots
-            if isinstance(metrics, list):
-                config["plots"] = {plot: {} for plot in plots}
+            cfg["plots"] = plots if isinstance(plots, dict) else {p: {} for p in plots}
 
         temp_outpath = "./temp_config.yml"
-        yaml.dump(config, open(temp_outpath, "w"))
-
+        with open(temp_outpath, "w") as f:
+            yaml.safe_dump(cfg, f, sort_keys=False)
         return temp_outpath
 
     return factory
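The reworked factory now always writes model and data paths (hierarchy or not) and returns the path of the temporary YAML it produced, so a test only has to choose `use_hierarchy` and pass any setting overrides. A sketch of how a test could consume it, mirroring the `plot_hierarchy_config` fixture added below; the test name is hypothetical and the `Config` import location is assumed from the package layout.

    # Hypothetical test using the reworked config_factory fixture; pytest injects
    # config_factory from conftest.py. The Config import path is an assumption.
    from deepdiagnostics.utils.config import Config

    def test_consumes_hierarchy_config(config_factory):
        cfg_path = config_factory(
            use_hierarchy=True,  # selects the hierarchy model/data resource paths
            metrics_settings={"samples_per_inference": 10, "use_progress_bar": False},
        )
        Config(cfg_path)  # loads the temporary YAML written by the factory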
diff --git a/tests/test_plots.py b/tests/test_plots.py
index cfbacd7..9b566fc 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -5,16 +5,18 @@
 from deepdiagnostics.plots import (
     CDFRanks,
+    HierarchyCDFRanks,
     Ranks,
+    HierarchyRanks,
     CoverageFraction,
     TARP,
     LC2ST,
     PPC,
     PriorPC,
-    Parity
+    Parity,
+    HierarchyParity
 )
 
-
 @pytest.fixture
 def plot_config(config_factory):
     metrics_settings = {
@@ -25,6 +27,15 @@ def plot_config(config_factory):
     config = config_factory(metrics_settings=metrics_settings)
     return config
 
+@pytest.fixture
+def plot_hierarchy_config(config_factory):
+    metrics_settings = {
+        "use_progress_bar": False,
+        "samples_per_inference": 10,
+        "percentiles": [95, 75, 50],
+    }
+    config = config_factory(metrics_settings=metrics_settings, use_hierarchy=True)
+    return config
 
 def test_plot_cdf(plot_config, mock_model, mock_data, mock_run_id):
     Config(plot_config)
@@ -32,12 +43,36 @@ def test_plot_cdf(plot_config, mock_model, mock_data, mock_run_id):
     plot(**get_item("plots", "CDFRanks", raise_exception=False))
     assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
 
+def test_plot_hierarchy_cdf(plot_hierarchy_config, mock_hierarchy_model, mock_hierarchy_data, mock_run_id):
+    Config(plot_hierarchy_config)
+    plot = HierarchyCDFRanks(mock_hierarchy_model, mock_hierarchy_data, global_samples=True,
+                             run_id=mock_run_id, save=True, show=False, parameter_names=["g"])
+    plot(**get_item("plots", "HierarchyCDFRanks", raise_exception=False))
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot = HierarchyCDFRanks(mock_hierarchy_model, mock_hierarchy_data, global_samples=False,
+                             run_id=mock_run_id, save=True, show=False, parameter_names=["l"])
+    plot(**get_item("plots", "HierarchyCDFRanks", raise_exception=False))
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
 def test_plot_ranks(plot_config, mock_model, mock_data, mock_run_id):
     Config(plot_config)
     plot = Ranks(mock_model, mock_data, mock_run_id, save=True, show=False)
     plot(**get_item("plots", "Ranks", raise_exception=False))
     assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
 
+def test_plot_hierarchy_ranks(plot_hierarchy_config, mock_hierarchy_model, mock_hierarchy_data, mock_run_id):
+    Config(plot_hierarchy_config)
+    plot = HierarchyRanks(mock_hierarchy_model, mock_hierarchy_data, global_samples=True,
+                          run_id=mock_run_id, save=True, show=False, parameter_names=["g"])
+    plot(**get_item("plots", "HierarchyRanks", raise_exception=False))
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot = HierarchyRanks(mock_hierarchy_model, mock_hierarchy_data, global_samples=False,
+                          run_id=mock_run_id, save=True, show=False, parameter_names=["l"])
+    plot(**get_item("plots", "HierarchyRanks", raise_exception=False))
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
 def test_plot_coverage(plot_config, mock_model, mock_data, mock_run_id):
     Config(plot_config)
     plot = CoverageFraction(mock_model, mock_data, mock_run_id, save=True, show=False)
@@ -125,6 +160,49 @@ def test_parity(plot_config, mock_model, mock_data, mock_run_id):
 
     assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
 
+def test_hierarchy_parity(plot_hierarchy_config, mock_hierarchy_model, mock_hierarchy_data, mock_run_id):
+    Config(plot_hierarchy_config)
+    plot = HierarchyParity(mock_hierarchy_model, mock_hierarchy_data, global_samples = True,
+                           run_id=mock_run_id, save=True, show=False, parameter_names=["g"])
+
+    plot(include_difference= False,
+        include_residual = False,
+        include_percentage = False)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot(include_difference= True,
+        include_residual = True,
+        include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot(include_difference= True,
+        include_residual = True,
+        include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot = HierarchyParity(mock_hierarchy_model, mock_hierarchy_data, global_samples = False,
+                           run_id=mock_run_id, save=True, show=False, parameter_names=["l"])
+
+    plot(include_difference= False,
+        include_residual = False,
+        include_percentage = False)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot(include_difference= True,
+        include_residual = True,
+        include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+    plot(include_difference= True,
+        include_residual = True,
+        include_percentage = True)
+
+    assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
 
 @pytest.mark.parametrize("plot_type", [CDFRanks, Ranks, CoverageFraction, TARP, PPC, PriorPC, Parity])
 def test_rerun_plot(plot_type, plot_config, mock_model, mock_data, mock_run_id):