diff --git a/WORKSPACE b/WORKSPACE index 445f974b094333..e42663c6922986 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,6 +43,7 @@ python_init_repositories( "3.10": "//:requirements_lock_3_10.txt", "3.11": "//:requirements_lock_3_11.txt", "3.12": "//:requirements_lock_3_12.txt", + "3.13": "//:requirements_lock_3_13.txt", }, ) diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 0cfbaf22f820b1..1776d117e62996 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -44,7 +44,7 @@ nvidia-cusparse-cu12 == 12.5.1.3 nvidia-nccl-cu12 == 2.25.1 nvidia-nvjitlink-cu12 == 12.5.82 # The dependencies below are needed for TF wheel testing. -tensorflow-io-gcs-filesystem==0.37.1 +# tensorflow-io-gcs-filesystem==0.37.1 libclang >= 13.0.0 google_pasta ~= 0.2 flatbuffers ~= 24.3.25 diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt new file mode 100644 index 00000000000000..0ff58cbf5a3f54 --- /dev/null +++ b/requirements_lock_3_13.txt @@ -0,0 +1,744 @@ +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via + # dm-tree + # keras-nightly + # tb-nightly +astor==0.7.1 \ + --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ + --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e + # via -r ci/official/requirements_updater/requirements.in +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via -r ci/official/requirements_updater/requirements.in +attrs==25.1.0 \ + --hash=sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e \ + --hash=sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a + # via dm-tree +auditwheel==6.2.0 \ + --hash=sha256:4fc9f778cd81dac56820e8cdee9842dc44b8f435f8783606dabd4964d4638b30 \ + --hash=sha256:927e0fc9ab5b6040c1240c81dd7f50924c99c3ca876a776d52e042ba374f6604 + # via -r ci/official/requirements_updater/requirements.in +certifi==2025.1.31 \ + --hash=sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651 \ + --hash=sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe + # via requests +charset-normalizer==3.4.1 \ + --hash=sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537 \ + --hash=sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa \ + --hash=sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a \ + --hash=sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294 \ + --hash=sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b \ + --hash=sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd \ + --hash=sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601 \ + --hash=sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd \ + --hash=sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4 \ + --hash=sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d \ + 
--hash=sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2 \ + --hash=sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313 \ + --hash=sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd \ + --hash=sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa \ + --hash=sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8 \ + --hash=sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1 \ + --hash=sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2 \ + --hash=sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496 \ + --hash=sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d \ + --hash=sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b \ + --hash=sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e \ + --hash=sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a \ + --hash=sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4 \ + --hash=sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca \ + --hash=sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78 \ + --hash=sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408 \ + --hash=sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5 \ + --hash=sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3 \ + --hash=sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f \ + --hash=sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a \ + --hash=sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765 \ + --hash=sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6 \ + --hash=sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146 \ + --hash=sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6 \ + --hash=sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9 \ + --hash=sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd \ + --hash=sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c \ + --hash=sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f \ + --hash=sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545 \ + --hash=sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176 \ + --hash=sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770 \ + --hash=sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824 \ + --hash=sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f \ + --hash=sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf \ + --hash=sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487 \ + --hash=sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d \ + --hash=sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd \ + --hash=sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b \ + --hash=sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534 \ + --hash=sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f \ + --hash=sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b \ + --hash=sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9 \ + 
--hash=sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd \ + --hash=sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125 \ + --hash=sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9 \ + --hash=sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de \ + --hash=sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11 \ + --hash=sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d \ + --hash=sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35 \ + --hash=sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f \ + --hash=sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda \ + --hash=sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7 \ + --hash=sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a \ + --hash=sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971 \ + --hash=sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8 \ + --hash=sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41 \ + --hash=sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d \ + --hash=sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f \ + --hash=sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757 \ + --hash=sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a \ + --hash=sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886 \ + --hash=sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77 \ + --hash=sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76 \ + --hash=sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247 \ + --hash=sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85 \ + --hash=sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb \ + --hash=sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7 \ + --hash=sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e \ + --hash=sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6 \ + --hash=sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037 \ + --hash=sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1 \ + --hash=sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e \ + --hash=sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807 \ + --hash=sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407 \ + --hash=sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c \ + --hash=sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12 \ + --hash=sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3 \ + --hash=sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089 \ + --hash=sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd \ + --hash=sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e \ + --hash=sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00 \ + --hash=sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616 + # via requests +dill==0.3.7 \ + --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ + --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 + # via -r 
ci/official/requirements_updater/requirements.in +dm-tree==0.1.9 \ + --hash=sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf \ + --hash=sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2 \ + --hash=sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9 \ + --hash=sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7 \ + --hash=sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15 \ + --hash=sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89 \ + --hash=sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15 \ + --hash=sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc \ + --hash=sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc \ + --hash=sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b \ + --hash=sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78 \ + --hash=sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c \ + --hash=sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c \ + --hash=sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8 \ + --hash=sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607 + # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in +gast==0.4.0 \ + --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ + --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in +grpcio==1.70.0 \ + --hash=sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655 \ + --hash=sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a \ + --hash=sha256:0a5c78d5198a1f0aa60006cd6eb1c912b4a1520b6a3968e677dbcba215fabb40 \ + --hash=sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7 \ + --hash=sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27 \ + --hash=sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1 \ + --hash=sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a \ + --hash=sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597 \ + --hash=sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839 \ + --hash=sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199 \ + --hash=sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5 \ + --hash=sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb \ + --hash=sha256:4119fed8abb7ff6c32e3d2255301e59c316c22d31ab812b3fbcbaf3d0d87cc68 \ + --hash=sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e \ + --hash=sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0 \ + --hash=sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2 \ + 
--hash=sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d \ + --hash=sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3 \ + --hash=sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b \ + --hash=sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873 \ + --hash=sha256:7c73c42102e4a5ec76608d9b60227d917cea46dff4d11d372f64cbeb56d259d0 \ + --hash=sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d \ + --hash=sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab \ + --hash=sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e \ + --hash=sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56 \ + --hash=sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851 \ + --hash=sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898 \ + --hash=sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c \ + --hash=sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40 \ + --hash=sha256:aa573896aeb7d7ce10b1fa425ba263e8dddd83d71530d1322fd3a16f31257b4a \ + --hash=sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c \ + --hash=sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f \ + --hash=sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c \ + --hash=sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f \ + --hash=sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113 \ + --hash=sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4 \ + --hash=sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f \ + --hash=sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77 \ + --hash=sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4 \ + --hash=sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528 \ + --hash=sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295 \ + --hash=sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca \ + --hash=sha256:d405b005018fd516c9ac529f4b4122342f60ec1cee181788249372524e6db429 \ + --hash=sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd \ + --hash=sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386 \ + --hash=sha256:dfa089a734f24ee5f6880c83d043e4f46bf812fcea5181dcb3a572db1e79e01c \ + --hash=sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1 \ + --hash=sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea \ + --hash=sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf \ + --hash=sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff \ + --hash=sha256:f19375f0300b96c0117aca118d400e76fede6db6e91f3c34b7b035822e06c35f \ + --hash=sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f \ + --hash=sha256:f32090238b720eb585248654db8e3afc87b48d26ac423c8dde8334a232ff53c9 \ + --hash=sha256:fe9dbd916df3b60e865258a8c72ac98f3ac9e2a9542dcb72b7a34d236242a5ce \ + --hash=sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly +h5py==3.13.0 \ + --hash=sha256:10894c55d46df502d82a7a4ed38f9c3fdbcb93efb42e25d275193e093071fade \ + --hash=sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3 \ + 
--hash=sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec \ + --hash=sha256:22ffe2a25770a2d67213a1b94f58006c14dce06933a42d2aaa0318c5868d1508 \ + --hash=sha256:337af114616f3656da0c83b68fcf53ecd9ce9989a700b0883a6e7c483c3235d4 \ + --hash=sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca \ + --hash=sha256:477c58307b6b9a2509c59c57811afb9f598aedede24a67da808262dfa0ee37b4 \ + --hash=sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a \ + --hash=sha256:5540daee2b236d9569c950b417f13fd112d51d78b4c43012de05774908dff3f5 \ + --hash=sha256:560e71220dc92dfa254b10a4dcb12d56b574d2d87e095db20466b32a93fec3f9 \ + --hash=sha256:56dd172d862e850823c4af02dc4ddbc308f042b85472ffdaca67f1598dff4a57 \ + --hash=sha256:57c4c74f627c616f02b7aec608a8c706fe08cb5b0ba7c08555a4eb1dde20805a \ + --hash=sha256:782ff0ac39f455f21fd1c8ebc007328f65f43d56718a89327eec76677ebf238a \ + --hash=sha256:82690e89c72b85addf4fc4d5058fb1e387b6c14eb063b0b879bf3f42c3b93c35 \ + --hash=sha256:851ae3a8563d87a5a0dc49c2e2529c75b8842582ccaefbf84297d2cfceeacd61 \ + --hash=sha256:8a8e38ef4ceb969f832cc230c0cf808c613cc47e31e768fd7b1106c55afa1cb8 \ + --hash=sha256:9c82ece71ed1c2b807b6628e3933bc6eae57ea21dac207dca3470e3ceaaf437c \ + --hash=sha256:be949b46b7388074c5acae017fbbe3e5ba303fd9daaa52157fdfef30bbdacadd \ + --hash=sha256:c10f061764d8dce0a9592ce08bfd5f243a00703325c388f1086037e5d619c5f1 \ + --hash=sha256:d2cf6a231a07c14acd504a945a6e9ec115e0007f675bde5e0de30a4dc8d86a31 \ + --hash=sha256:d571644958c5e19a61c793d8d23cd02479572da828e333498c9acc463f4a3997 \ + --hash=sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d \ + --hash=sha256:e520ec76de00943dd017c8ea3f354fa1d2f542eac994811943a8faedf2a7d5cb \ + --hash=sha256:e79d8368cd9295045956bfb436656bea3f915beaa11d342e9f79f129f5178763 \ + --hash=sha256:f35640e81b03c02a88b8bf99fb6a9d3023cc52f7c627694db2f379e0028f2868 \ + --hash=sha256:fb267ce4b83f9c42560e9ff4d30f60f7ae492eacf9c7ede849edf8c1b860e16b + # via + # -r ci/official/requirements_updater/requirements.in + # keras-nightly +idna==3.10 \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 + # via requests +jax==0.4.7 \ + --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 + # via -r ci/official/requirements_updater/requirements.in +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r 
ci/official/requirements_updater/requirements.in +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b + # via -r ci/official/requirements_updater/requirements.in +markdown==3.7 \ + --hash=sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2 \ + --hash=sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803 + # via tb-nightly +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +markupsafe==3.0.2 \ + --hash=sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4 \ + --hash=sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30 \ + --hash=sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0 \ + --hash=sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9 \ + --hash=sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396 \ + --hash=sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13 \ + --hash=sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028 \ + --hash=sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca \ + --hash=sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557 \ + --hash=sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832 \ + --hash=sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0 \ + --hash=sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b \ + --hash=sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579 \ + --hash=sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a \ + --hash=sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c \ + --hash=sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff \ + --hash=sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c \ + --hash=sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22 \ + --hash=sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094 \ + --hash=sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb \ + --hash=sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e \ + --hash=sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5 \ + --hash=sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a \ + --hash=sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d \ + --hash=sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a \ + --hash=sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b \ + --hash=sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8 \ + --hash=sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225 \ + --hash=sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c \ + --hash=sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144 \ + --hash=sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f \ + --hash=sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87 \ + --hash=sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d \ + --hash=sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93 \ + --hash=sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf \ + 
--hash=sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158 \ + --hash=sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84 \ + --hash=sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb \ + --hash=sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48 \ + --hash=sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171 \ + --hash=sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c \ + --hash=sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6 \ + --hash=sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd \ + --hash=sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d \ + --hash=sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1 \ + --hash=sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d \ + --hash=sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca \ + --hash=sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a \ + --hash=sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29 \ + --hash=sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe \ + --hash=sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798 \ + --hash=sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c \ + --hash=sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8 \ + --hash=sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f \ + --hash=sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f \ + --hash=sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a \ + --hash=sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178 \ + --hash=sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0 \ + --hash=sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79 \ + --hash=sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430 \ + --hash=sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50 + # via werkzeug +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 
\ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 + # via + # -r ci/official/requirements_updater/requirements.in + # jax + # keras-nightly +namex==0.0.8 \ + --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ + --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 + # via keras-nightly +numpy==2.1.3 \ + --hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \ + --hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \ + --hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \ + --hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \ + --hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \ + --hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \ + --hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \ + --hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \ + --hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \ + --hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \ + --hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \ + --hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \ + --hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \ + --hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \ + --hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \ + --hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \ + --hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \ + --hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \ + --hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \ + --hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \ + --hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \ + --hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \ + --hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \ + --hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \ + --hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \ + --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ + --hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \ + --hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \ + --hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \ + 
--hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \ + --hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \ + --hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \ + --hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \ + --hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \ + --hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \ + --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ + --hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \ + --hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \ + --hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \ + --hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \ + --hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \ + --hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \ + --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \ + --hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \ + --hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \ + --hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \ + --hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \ + --hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \ + --hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \ + --hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \ + --hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \ + --hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \ + --hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \ + --hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \ + --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree + # h5py + # jax + # keras-nightly + # ml-dtypes + # opt-einsum + # scipy + # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + 
--hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via + # -r ci/official/requirements_updater/requirements.in + # jax +packaging==23.2 \ + --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ + --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 + # via + # -r ci/official/requirements_updater/requirements.in + # auditwheel + # tb-nightly +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r ci/official/requirements_updater/requirements.in +protobuf==5.29.3 \ + --hash=sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f \ + 
--hash=sha256:0eb32bfa5219fc8d4111803e9a690658aa2e6366384fd0851064b963b6d1f2a7 \ + --hash=sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888 \ + --hash=sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620 \ + --hash=sha256:6ce8cc3389a20693bfde6c6562e03474c40851b44975c9b2bf6df7d8c4f864da \ + --hash=sha256:84a57163a0ccef3f96e4b6a20516cedcf5bb3a95a657131c5c3ac62200d23252 \ + --hash=sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a \ + --hash=sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e \ + --hash=sha256:b89c115d877892a512f79a8114564fb435943b59067615894c3b13cd3e1fa107 \ + --hash=sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f \ + --hash=sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84 + # via tb-nightly +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + --hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + --hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.1 \ + --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ + --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c + # via rich +requests==2.32.3 \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r ci/official/requirements_updater/requirements.in +rich==13.9.4 \ + --hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \ + --hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90 + # via keras-nightly +scipy==1.15.2 \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + 
--hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r ci/official/requirements_updater/requirements.in + # jax +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # tb-nightly +tb-nightly==2.19.0a20250218 \ + --hash=sha256:7c7fea911a9e113e7d40fa9aed96168840e2443c5ada52fba5bc3645ec6e206f + # via -r ci/official/requirements_updater/requirements.in +tblib==2.0.0 \ + --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ + 
--hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 + # via -r ci/official/requirements_updater/requirements.in +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tb-nightly +termcolor==2.3.0 \ + --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ + --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a + # via -r ci/official/requirements_updater/requirements.in +typing-extensions==4.8.0 \ + --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ + --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef + # via -r ci/official/requirements_updater/requirements.in +urllib3==2.3.0 \ + --hash=sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df \ + --hash=sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d + # via requests +werkzeug==3.1.3 \ + --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ + --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 + # via tb-nightly +wheel==0.41.3 \ + --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ + --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 + # via + # -r ci/official/requirements_updater/requirements.in + # astunparse +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + 
--hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + 
--hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree + +# The following packages are considered to be unsafe in a requirements file: +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD index 44b35ae4c58604..f572ad6418feba 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -37,6 +38,62 @@ cc_library( ], ) +cc_library( + name = "quantization_lib", + srcs = [ + "quantization_driver.cc", + "quantization_interface.cc.inc", + "quantization_utils.cc", + ], + hdrs = [ + "quantization_driver.h", + "quantization_interface.h.inc", + "quantization_traits.h", + "quantization_utils.h", + ], + deps = [ + ":quantization_config", + ":quantization_interfaces_inc_gen", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/tools/optimize:quantization_utils", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "quantization_driver_test", + srcs = ["quantization_driver_test.cc"], + deps = [ + ":quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + td_library( name = "quantization_td_files", srcs = [ @@ -87,5 +144,7 @@ cc_library( ) exports_files([ + "quantization_traits.h", "quantization_config.h", + "quantization_utils.h", ]) diff --git 
a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td
index 0f9b6a74762f9b..690fe4be1d46eb 100644
--- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td
+++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td
@@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td"
 // explicit signedness check to differentiate the signed/unsigned constraints
 // predicates from one another at the TD level.
 class QuantizedType params, bit signed>
-  : Type()">,
-        CPred<"$_self.cast()" #
+  : Type()">,
+        CPred<"$_self.cast()" #
              ".getStorageTypeIntegralWidth() == " # !head(params)>,
-    Or<[CPred<"$_self.cast()" #
+    Or<[CPred<"$_self.cast()" #
             ".getStorageType().isSignlessInteger()">,
-        CPred<"$_self.cast()" #
+        CPred<"$_self.cast()" #
             ".getStorageType().isSignedInteger() == " # signed>]>]>,
    "Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
  string name = n;
@@ -73,6 +73,8 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
 def FixedOutputRangeInterface : OpInterface<
     "FixedOutputRangeInterface"> {
+  let cppNamespace = "TFL";
+
   let description = [{
     Interface for defining the fixed output range.
   }];
@@ -88,6 +90,8 @@ def AffineQuantizedOpInterface : OpInterface<
     "AffineQuantizedOpInterface"> {
+  let cppNamespace = "TFL";
+
   let description = [{
     Interface for affine quantized ops (conv2d, fully_connected, etc.)
   }];
@@ -113,6 +117,8 @@
 }
 def SameOperandsAndResultsScale : OpInterface<"SameScalesOpInterface"> {
+  let cppNamespace = "TFL";
+
   let description = [{
     Interface for ops potentially have same operands and results scales.
   }];
@@ -131,12 +137,14 @@
   ];
   let verify = [{
-    return quant::VerifySameScales($_op);
+    return TFL::VerifySameScales($_op);
   }];
 }
 def DynamicRangeQuantizedOpInterface : OpInterface<
     "DynamicRangeQuantizedOpInterface"> {
+  let cppNamespace = "TFL";
+
   let description = [{
     Interface for ops dynamic range quantization is supported.
@@ -196,12 +204,12 @@
 // Specify this trait if the op has a fixed output value range.
 class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
-    "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
+    "TFL::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
 // Specify this trait if the bias-th input of the op is a bias input, which
 // needs a scale based on the scales of op1 and op2.
 class AccumulatorUniformScale : NativeOpTrait<
-    !strconcat("quant::AccumulatorUniformScale<",
+    !strconcat("TFL::AccumulatorUniformScale<",
               !interleave([bias, op1, op2], ", "),
               ">::Impl")>;
@@ -209,11 +217,11 @@
 // and also the quantization dimension if per-axis quantization is support.
 // If the quantization dimension is -1, per-axis quantization isn't supported.
 class AffineOpCoefficient : NativeOpTrait<
-    !strconcat("quant::AffineOpCoefficient<",
+    !strconcat("TFL::AffineOpCoefficient<",
               !interleave([dim, index], ", "),
               ">::Impl")>;
 // Specify this trait if the op does have quantizable output. Quantizers will
 // apply quantization on this op.
-def QuantizableResult : NativeOpTrait<"quant::QuantizableResult">; +def QuantizableResult : NativeOpTrait<"TFL::QuantizableResult">; #endif // TF_Quantization diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc index daa787589ea7b1..9aef4058cf960d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc @@ -121,7 +121,8 @@ bool ParseInputNodeQuantSpecs(const absl::string_view node_names, } tensorflow::DataType final_type = tensorflow::DT_FLOAT; - if (!inference_type.empty() && !DataType_Parse(inference_type, &final_type)) { + if (!inference_type.empty() && + !DataType_Parse(std::string(inference_type), &final_type)) { return true; } return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type, diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc new file mode 100644 index 00000000000000..0ce7f43cd24fe7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc @@ -0,0 +1,958 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" + +namespace mlir { +namespace TFL { +namespace { + +constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; + +// Uses the type of `value` to set the initial state of the index-th result if +// `as_result` is true or index-th operand if `as_result` is false. The state +// is immutable if the type is a quantized type. Returns the index of this +// new state in the state vector. 
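+// For example, if the same SSA value appears as an operand of two different
+// ops, the second call finds the cached entry and records the existing state
+// index instead of allocating a new `QuantState`.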
+void InitializeStateForValue( + Operation* op, const int index, const Value value, const bool as_result, + std::vector& states, + DenseMap& value_to_state, + DenseMap& operand_states, + DenseMap& result_states) { + const auto [cached, inserted] = value_to_state.try_emplace(value, 0); + if (!inserted) { + if (as_result) { + result_states[{op, index}] = cached->second; + } else { + operand_states[{op, index}] = cached->second; + } + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(value.getType()); + + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states.size(); + states.push_back({quantized_type, immutable}); + if (as_result) { + result_states[{op, index}] = next_state_index; + } else { + operand_states[{op, index}] = next_state_index; + } + + cached->second = next_state_index; +} + +bool HasPerAxisQuantizedOperand(Operation* op) { + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto dq_op = dyn_cast_or_null( + op->getOperand(i).getDefiningOp())) { + auto type = + mlir::cast(dq_op.getArg().getType()).getElementType(); + if (auto per_axis_qtype = + mlir::dyn_cast_or_null( + QuantizedType::getQuantizedElementType(type))) { + return true; + } + } + } + return false; +} + +} // namespace + +void QuantizationDriver::InitializeArgState(const BlockArgument arg, + const Value arg_value) { + const auto [cached, inserted] = value_to_state_.try_emplace(arg_value, 0); + if (!inserted) { + arg_states_[arg] = cached->second; + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(arg_value.getType()); + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states_.size(); + states_.push_back({quantized_type, immutable}); + arg_states_[arg] = next_state_index; + cached->second = next_state_index; +} + +void QuantizationDriver::InitializeOperandState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/false, states_, + value_to_state_, operand_states_, result_states_); +} + +void QuantizationDriver::InitializeResultState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/true, states_, + value_to_state_, operand_states_, result_states_); +} + +std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { + return op_quant_spec_getter_(op); +} + +std::unique_ptr QuantizationDriver::GetQuantScaleSpec( + Operation* op) { + return op_quant_scale_spec_getter_(op); +} + +bool QuantizationDriver::IsQuantized(Operation* op) { + for (int i = 0; i < op->getNumResults(); ++i) { + if (GetResultQuantState(op, i).IsEmpty()) return false; + } + return true; +} + +bool QuantizationDriver::SetConstantResultParams(Operation* op) { + DenseFPElementsAttr attr; + const Value result = op->getResult(0); + if (!matchPattern(result, m_Constant(&attr))) { + return false; + } + // TODO: b/323478683 - Make storage_type_width and narrow_range configurable. 
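+  // In short (see the branches below): weights with per-axis support, when
+  // per-channel quantization is enabled, get symmetric narrow-range 8-bit
+  // parameters along the recorded quantization dimension; everything else
+  // falls back to a single per-tensor type.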
+ Type final_type; + const auto it = optimized_weights_.find(op); + const bool is_weight = it != optimized_weights_.end(); + const bool is_weight_with_per_channel_support = + is_weight && it->second != -1 && is_signed_; + + if (is_weight_with_per_channel_support && !disable_per_channel_) { + // When `disable_per_channel_` is false, per-channel symmetric quantization + // parameters are created from the weights when the ops support per-channel + // quantization. Otherwise, uses per-tensor asymmetric quantization with + // narrow range. + + // per-axis quantization weight, with symmetric min/max enforced. + final_type = GetUniformQuantizedPerAxisTypeForWeight( + attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_, + /*narrow_range=*/true, legacy_float_scale_); + } else { + // per-tensor quantization weight + final_type = GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/is_weight && is_signed_, + /*num_bits=*/8, is_signed_, + /*narrow_range=*/is_weight, legacy_float_scale_); + } + if (const auto quant_type = mlir::dyn_cast_or_null(final_type); + quant_type != nullptr) { + return SetResultParams(op, /*result_index=*/0, quant_type); + } + return false; +} + +bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, + const QuantizedType quantized_type) { + QuantState& state = GetResultQuantState(op, result_index); + if (state.params == quantized_type) { + return false; + } + if (!state.IsEmpty()) { + RequantizeStates& rescales = GetResultRequantizeStates(op, result_index); + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_INPUT; + rescale.params = quantized_type; + return true; + } + state.params = quantized_type; + AddUserToList(op, result_index); + return true; +} + +QuantizedType QuantizationDriver::GetBiasParams( + Operation* op, const int bias_index, + const ArrayRef non_bias_operand_indices, + const AccumulatorScaleFunc func) { + QuantState& bias_state = GetOperandQuantState(op, bias_index); + if (!bias_state.IsEmpty()) { + return bias_state.params; + } + std::vector op_types{}; + op_types.reserve(non_bias_operand_indices.size()); + + int adjusted_quant_dim = -1; + if (op->getNumOperands() > bias_index) { + // Some kernels allow 1D bias, broadcasting it inside the kernel. In this + // case, the `quantizedDimension=0` when quantizing per-channel. + // However, for some kernels which require bias to be already broadcasted + // to match the accumulation shape, the very last index should be used. + Operation* bias_op = op->getOperand(bias_index).getDefiningOp(); + if (bias_op != nullptr) { + Type bias_type = bias_op->getResult(0).getType(); + if (bias_type != builder_.getNoneType()) { + const int bias_rank = mlir::dyn_cast(bias_type).getRank(); + adjusted_quant_dim = bias_rank > 1 ? 
bias_rank - 1 : 0; + } + } + } + + for (const int non_bias_operand_index : non_bias_operand_indices) { + const QuantState& non_bias_state = + GetOperandQuantState(op, non_bias_operand_index); + op_types.push_back(non_bias_state.params); + } + return func(op_types, adjusted_quant_dim, legacy_float_scale_); +} + +bool QuantizationDriver::SetOperandParams(Operation* op, + const int operand_index, + const QuantizedType quantized_type, + const bool override) { + QuantState& state = GetOperandQuantState(op, operand_index); + if (state.params == quantized_type) { + return false; + } + + if (!state.IsEmpty() && !override) { + RequantizeStates& rescales = GetOperandRequantizeStates(op, operand_index); + for (RequantizeState& rescale : rescales) { + if (rescale.params == quantized_type) { + rescale.users.emplace_back(op, operand_index); + return true; + } + } + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_OUTPUT; + rescale.params = quantized_type; + rescale.users.emplace_back(op, operand_index); + return true; + } + + state.params = quantized_type; + AddOperandToList(op, operand_index); + return true; +} + +void QuantizationDriver::QuantizeOpResult(Operation* op, const int result_index, + const QuantizedType quantized_type) { + builder_.setInsertionPointAfter(op); + const Value original_result = op->getResult(result_index); + QuantizeValue(original_result, quantized_type, op->getLoc()); +} + +void QuantizationDriver::QuantizeArg(BlockArgument arg, + const QuantizedType quantized_type) { + builder_.setInsertionPointToStart(arg.getOwner()); + QuantizeValue(arg, quantized_type, builder_.getUnknownLoc()); +} + +void QuantizationDriver::QuantizeValue(Value value, + QuantizedType quantized_type, + const Location loc) { + const Type expressed_type = value.getType(); + const Type new_value_type = + quantized_type.castFromExpressedType(expressed_type); + // Skip if `value` or `value`'s element type doesn't match the expressed type + // of `quantized_type`. + if (new_value_type == nullptr) return; + + auto quantize = + builder_.create(loc, new_value_type, value); + auto dequantize = builder_.create( + loc, expressed_type, quantize.getResult()); + + // This attribute is set to distinguish the quantize ops being added by the + // quantization pass. These ops can be removed without losing original + // program accuracy. + // TODO: b/323478683 - Make the attribute being part of op definition. + quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); + + // `original_result` has a use to `quantize`, so this will replace that use + // by the result of `dequantize`. Remember to reset that use afterwards + value.replaceAllUsesWith(dequantize); + quantize.getOperation()->replaceUsesOfWith(dequantize, value); +} + +void QuantizationDriver::RequantizeOpResult(Operation* op, + const int result_index, + RequantizeStates& states) { + if (states.empty()) return; + + builder_.setInsertionPointAfter(op); + Value value = op->getResult(result_index); + RequantizeState::RequantizePosition pos = states.front().pos; + if (pos == RequantizeState::NO_REQUANTIZE) { + return; + } + for (const RequantizeState& state : states) { + // Check that all requantization positions are the same for each state. + // Unsure if this check is required. + if (state.pos != pos) { + return; + } + } + if (pos == RequantizeState::ON_OUTPUT) { + Operation* user = value.getUses().begin().getUser(); + if (isa(user)) { + // The requantize op is inserted between `quantize` and `dequantize` ops. 
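+      // e.g. the chain result -> QuantizeCast -> DequantizeCast becomes
+      // result -> QuantizeCast -> QuantizeCast (new params) -> DequantizeCast.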
+ value = user->getResult(0); + builder_.setInsertionPointAfter(user); + } + } + RequantizeValue(value, states, op->getLoc()); +} + +void QuantizationDriver::RequantizeArg(const BlockArgument arg, + RequantizeStates& states) { + Value value = arg; + builder_.setInsertionPointToStart(arg.getOwner()); + if (value.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); + } + } + RequantizeValue(value, states, builder_.getUnknownLoc()); +} + +void QuantizationDriver::RequantizeValue(Value value, RequantizeStates& states, + const Location loc) { + if (states.empty() || states.front().pos == RequantizeState::NO_REQUANTIZE) { + return; + } + if (states.front().pos == RequantizeState::ON_INPUT) { + RequantizeState& state = states.front(); + const Type expressed_type = value.getType(); + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + if (!new_type) return; + auto requantize_op = + builder_.create(loc, new_type, value); + value.replaceAllUsesWith(requantize_op); + requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value); + // This requantization was defined as required for the result value, so + // there should be only one requant state. + return; + } + + // If this is an operand that requires requantization, then the value should + // only have one `DequantizeCastOp` user which produces the operand value. + if (!value.hasOneUse()) { + return; + } + auto dequant_op = dyn_cast_or_null( + value.use_begin().getUser()); + if (!dequant_op) { + return; + } + // It is possible that the dequant value is used by a op that doesn't require + // requant, so only overwrite the first if that is not the case. + const int num_uses = std::distance(dequant_op.getResult().use_begin(), + dequant_op.getResult().use_end()); + + // Whether to replace quantization params of the first dequantize op + // after the quantized value is produced. + // If there is a use other than the requantize states, then we can't clobber. + bool clobber_first = num_uses <= states.size(); + for (RequantizeState& state : states) { + Type expressed_type = QuantizedType::castToExpressedType(value.getType()); + if (!expressed_type) continue; + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + // This value isn't an expressed type (float), skip. + if (!new_type) continue; + + auto requantize_op = + builder_.create(loc, new_type, value); + + if (clobber_first) { + dequant_op.setOperand(requantize_op.getResult()); + // All ops requiring this value already use the result of dequant. 
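+      // i.e. rewiring the existing DequantizeCast to consume the requantize
+      // result is enough for this first state; the remaining states each get
+      // their own DequantizeCast in the else branch below.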
+ clobber_first = false; + } else { + auto new_dequant_op = builder_.create( + loc, dequant_op.getResult().getType(), requantize_op.getResult()); + for (auto [op, operand_idx] : state.users) { + op->setOperand(operand_idx, new_dequant_op.getResult()); + } + } + } +} + +// A heuristic to get quantization parameters satisfies the same scale +// constraints: +// - If there are immutable states, +// - use the single input, or, +// - use the single output, or, +// - use the first one in the collection, +// - use the single input if it is ready, or, +// - use the single output if it is ready, or, +// - use the first ready one in the collection. +QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( + Operation* op) { + // Two vector to collect Non-empty operands and results states. + std::vector mutable_states, immutable_states; + for (int i = 0; i < op->getNumOperands(); ++i) { + QuantState& state = GetOperandQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_operands_num = immutable_states.size(); + const int mutable_operands_num = mutable_states.size(); + // Use the operand's state if it is immutable and it is the only one + // operand. + if (op->getNumOperands() == 1 && immutable_operands_num == 1) { + return immutable_states.front()->params; + } + + for (int i = 0; i < op->getNumResults(); ++i) { + QuantState& state = GetResultQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_results_num = + immutable_states.size() - immutable_operands_num; + const int mutable_results_num = mutable_states.size() - mutable_operands_num; + // Use the result's state if it is immutable and it is the only one result. + if (op->getNumResults() == 1 && immutable_results_num == 1) { + return immutable_states.back()->params; + } + + // Use the first immutable state to quantize the rest operands and results. + if (!immutable_states.empty()) return immutable_states.front()->params; + + // If there are no immutable states, use the operand's state if it is the + // only one operand and has parameters propagated. + if (op->getNumOperands() == 1 && mutable_operands_num == 1) { + return mutable_states.front()->params; + } + + // If there are no immutable states, use the result's state if it is the + // only one result and has parameters propagated. + if (op->getNumResults() == 1 && mutable_results_num == 1) { + return mutable_states.back()->params; + } + + // Use the first propagated state to quantize the rest operands and results. + if (!mutable_states.empty()) return mutable_states.front()->params; + + // None operands/results have parameters propagated, skip this node for now. + return {}; +} + +void QuantizationDriver::PreprocessConstantOps() { + fn_.walk([&](arith::ConstantOp cst) { + // Non-float tensors are neither weights nor require quantization. + const auto type = mlir::dyn_cast(cst.getType()); + if (!type || !mlir::isa(type.getElementType())) return; + + // Skip if the value is NaN or INF. + // Otherwise the illegal scale/zp will be calculated. 
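+    // e.g. a constant splat of +INF would otherwise yield a non-finite
+    // min/max and hence an invalid scale once weight parameters are derived.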
+ auto float_attr = mlir::dyn_cast(cst.getValueAttr()); + if (float_attr && (float_attr.getValues().empty() || + !float_attr.getValues()[0].isFinite())) { + return; + } + + const Value value = cst.getResult(); + builder_.setInsertionPoint(cst); + + // The following loop will change the value uses, thus we cache all the uses + // needs to be changed. + SmallVector> uses; + for (OpOperand& use : value.getUses()) { + uses.push_back({use.getOwner(), use.getOperandNumber()}); + } + for (const auto [user, operand_num] : uses) { + const std::unique_ptr spec = GetQuantSpec(user); + const std::unique_ptr scale_spec = + GetQuantScaleSpec(user); + const BiasParamsMap biases = spec->biases_params; + + // The quantization parameters of a `weight` shouldn't be determined by + // other values. So any constants which are not bias, an operand of an + // op with same scale requirements, and haven't been quantized are + // weights. + if (!biases.contains(operand_num) && + !scale_spec->has_same_scale_requirement && + !dyn_cast(user)) { + // Needs to scan the content of weights to get the quantization + // parameters if there are no quantization parameters (FakeQuant ops). + // For this case, the weight will not be duplicated. + weights_.insert(cst); + if (spec->coeff_op_quant_dim.find(operand_num) != + spec->coeff_op_quant_dim.end()) { + optimized_weights_.insert( + {cst, spec->coeff_op_quant_dim[operand_num]}); + } + } else { + // This is a bias or an operand of an op with same scale requirements, + // so the quantization parameter are propagated from or determined by + // other values. Duplicate this constant in case it is shared by + // different users. + if (uses.size() > 1) { + auto new_constant_op = + builder_.create(cst.getLoc(), cst.getValue()); + user->setOperand(operand_num, new_constant_op); + } + } + } + }); +} + +void QuantizationDriver::SetupAllStates() { + for (BlockArgument arg : fn_.getArguments()) { + args_.push_back(arg); + Value value = arg; + // If the argument is quantized, it should only has one user. + if (arg.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + } + } + InitializeArgState(arg, value); + } + + fn_.walk([&](Operation* op) { + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + if (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) { + return; + } + work_list_.push_back(op); + + for (int i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (Operation* inst = operand.getDefiningOp()) { + // If the operand comes from a `quantfork::DequantizeCastOp`, we use + // the quantized input of this `quantfork::DequantizeCastOp` to set the + // state. + if (auto dq = dyn_cast(inst)) { + operand = dq.getArg(); + } + } + InitializeOperandState(op, i, operand); + } + + for (int i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + // If the result has been quantized, it should only be used by a + // `quantfork::QuantizeCastOp`. For this case, we uses the quantized + // result to create the state and mark it immutable. 
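+      // Illustrative IR only: for %r = "tfl.conv_2d"(...) followed by
+      // %q = "quantfork.qcast"(%r) : ... -> !quant.uniform<i8:f32, ...>,
+      // the state for this result is seeded from %q's quantized type and
+      // marked immutable.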
+ if (result.hasOneUse()) { + Operation* user = result.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + result = q.getResult(); + } + } + InitializeResultState(op, i, result); + } + }); +} + +arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( + arith::ConstantOp op, Operation* target_op, const int operand_index) { + if (op.getResult().hasOneUse()) { + return op; + } + OpBuilder builder(op->getContext()); + builder.setInsertionPointAfter(op); + arith::ConstantOp new_op = cast(builder.clone(*op)); + target_op->getOpOperand(operand_index).set(new_op.getResult()); + InitializeOperandState(target_op, operand_index, new_op.getResult()); + InitializeResultState(new_op, 0, new_op.getResult()); + return new_op; +} + +bool QuantizationDriver::ShouldCheckBiasScale( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType quantized_type, int& input_index, int& filter_index) { + // For now, restrict scale adjustment to ops with affine quantized weights, + // and having weights and biases as constants. This currently only applies to + // FC and Conv* ops. Restriction for the weight can be relaxed if there are + // needs for adjusting scale of variable weights. + auto affine_op = dyn_cast(op); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + if (!affine_op || !bias_op || input_indices.size() != 2) return false; + if (!mlir::isa(bias_op.getValue())) return false; + filter_index = affine_op.GetAffineOperandIndex(); + if (!op->getOperand(filter_index).getDefiningOp()) { + return false; + } + if (filter_index == input_indices[0]) { + input_index = input_indices[1]; + } else if (filter_index == input_indices[1]) { + input_index = input_indices[0]; + } else { + return false; + } + + const QuantState& input_state = GetOperandQuantState(op, input_index); + const QuantState& filter_state = GetOperandQuantState(op, filter_index); + // If quantization parameter for the filter is fixed, should return it as-is. + // Only checks ops with 8-bit input and weights, and 32-bit biases. + return input_state.params.getStorageTypeIntegralWidth() == 8 && + filter_state.params.getStorageTypeIntegralWidth() == 8 && + quantized_type.getStorageTypeIntegralWidth() == 32; +} + +bool QuantizationDriver::SetBiasParamsWithAdjustments( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType params) { + bool changed = false; + + int input_index; + int filter_index; + if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index, + filter_index)) { + return SetOperandParams(op, bias_index, params); + } + + QuantState input_state = GetOperandQuantState(op, input_index); + QuantState filter_state = GetOperandQuantState(op, filter_index); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + const double input_scale = + mlir::cast(input_state.params).getScale(); + + auto bias_values = mlir::cast(bias_op.getValue()); + // Restrict maximum absolute value of bias within INT_MAX / 2, to make some + // room for accumulator. 
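+  // Worked example (sketch): with kBiasMax ~= 1.07e9, a largest |bias| of
+  // 2.0e5 and a derived bias scale of 1.0e-5, the quantized bias would be
+  // 2.0e10 > kBiasMax, so the bias scale is widened to 2.0e5 / kBiasMax
+  // ~= 1.9e-4 and the filter scale is adjusted to new_bias_scale / input_scale
+  // below.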
+ if (auto bias_quantized_type = mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + double bias_half_range = 0.0f; + for (auto bias : bias_values.getValues()) { + if (bias_half_range < std::abs(bias.convertToFloat())) { + bias_half_range = std::abs(bias.convertToFloat()); + } + } + if (bias_half_range / bias_quantized_type.getScale() < kBiasMax) { + return SetOperandParams(op, bias_index, params); + } + const double new_bias_scale = + static_cast(bias_half_range) / kBiasMax; + + changed |= SetOperandParams( + op, bias_index, + UniformQuantizedType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scale, 0, + params.getStorageTypeMin(), params.getStorageTypeMax())); + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + if (!filter_op) { + return SetOperandParams(op, bias_index, params); + } + + const auto filter_quantized_type = + mlir::cast(filter_state.params); + changed |= SetOperandParams( + op, filter_index, + UniformQuantizedType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), + new_bias_scale / input_scale, 0, + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } else if (auto bias_quantized_type = + mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + const auto filter_quantized_type = + mlir::cast(filter_state.params); + std::vector new_bias_scales = bias_quantized_type.getScales().vec(); + std::vector new_filter_scales = + filter_quantized_type.getScales().vec(); + + bool needs_adjustment = false; + for (int i = 0; i < bias_quantized_type.getScales().size(); ++i) { + const float abs_bias = std::abs(bias_values.getValues()[i]); + if (abs_bias / new_bias_scales[i] > kBiasMax) { + new_bias_scales[i] = static_cast(abs_bias) / kBiasMax; + new_filter_scales[i] = new_bias_scales[i] / input_scale; + needs_adjustment = true; + } + } + if (!needs_adjustment) { + return SetOperandParams(op, bias_index, params); + } + changed |= SetOperandParams( + op, bias_index, + quant::UniformQuantizedPerAxisType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scales, + bias_quantized_type.getZeroPoints(), + bias_quantized_type.getQuantizedDimension(), + params.getStorageTypeMin(), params.getStorageTypeMax())); + + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + changed |= SetOperandParams( + op, filter_index, + quant::UniformQuantizedPerAxisType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), new_filter_scales, + filter_quantized_type.getZeroPoints(), + filter_quantized_type.getQuantizedDimension(), + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } + return changed; +} + +// This method scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// TODO: b/323478683 - This algorithm assumes there are only one pair of +// `quantfork::QuantizeCastOp` and `quantfork::DequantizeCastOp` ops between two +// quantizable ops. A sanity check should be applied. 
+void QuantizationDriver::Initialize() { + // Duplicate the bias constant, so the states can be setup correctly. + // TODO: b/323478683 - Function definition should also be duplicated if there + // are multiple call sites. + PreprocessConstantOps(); + + // Setup all the internal states. + SetupAllStates(); +} + +// Propagates the quantization parameters to the operands, results, and biases. +// TODO: b/323478683 - Do not use while loop to handle this logic. +bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { + // TODO: b/323478683 - Use a typed indicator instead of a bool value. + bool changed = false; + while (!work_list_.empty()) { + Operation* op = work_list_.back(); + work_list_.pop_back(); + + // This op has been quantized, so we should not consider it again. + if (quantized_.contains(op)) continue; + quantized_.insert(op); + + if (auto constant_op = dyn_cast(op); constant_op) { + // If the workflow requires inferring ranges from the content + // (post-training quantization) and it is weight (filter) and hasn't + // been quantized, we infer the quantization parameters from the content. + if (infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { + // The quantization parameters are determined by the content of the + // constant. + changed |= SetConstantResultParams(op); + } + continue; + } + + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + + if (scale_spec->has_same_scale_requirement) { + const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); + // The quantization parameters haven't been propagated to any operands + // or results. Skip this node for now. + if (!params) { + quantized_.erase(op); + continue; + } + + // If this is a QDQ conversion only, the op could have a same-scale + // requirement for the floating point kernel but allow per-axis + // quantization for the quantized kernel. If the quantized dimension + // changes, the following logic no longer works as the same `params` + // shouldn't be used for both input and output quantization params. + // E.g. During TransposeOp's quantization propagation in + // PrepareQuantize, if the quantization is per-axis and the + // QuantizedDimension is transposed, then the output q-dq params must + // reflect the new QuantizedDimension. So, check and skip the + // propagation if any of the operands has a per-axis quantized type param + // and `RequiredSameQuantizedAxes` set to false. + // Currently, these lines of code are only applicable to TFL_TransposeOp + // and TFL_ReshapeOp. And the output q-dq propagation for this Op is + // performed in `PropagateTransposedPerAxisQuantDim` and + // `PropagateReshapedPerAxisQuantDim` respectively. + if (is_qdq_conversion_ && + !scale_spec->required_same_quantized_axes_func()) { + if (HasPerAxisQuantizedOperand(op)) continue; + } + + // Use the final state to set all the operands' parameters. + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto type = + mlir::dyn_cast(op->getOperand(i).getType())) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float tensors. + if (mlir::isa(type.getElementType())) + changed |= SetOperandParams(op, i, params); + } + } + + // Use the final state to set all the results' parameters. + for (int i = 0; i < op->getNumResults(); ++i) + if (auto type = mlir::dyn_cast(op->getResult(i).getType()); + type != nullptr) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float-tensors. 
+ if (mlir::isa(type.getElementType())) + changed |= SetResultParams(op, i, params); + } + } + + // If the model already contains immutable QDQs, require upstream to + // explicitly fix output range instead. + if (scale_spec->has_fixed_output_range && infer_tensor_range_ && + !is_qdq_conversion_) { + // Infer ranges from the activation ops. This is usually required for + // the post-training quantization workflow. + // TODO: b/323478683 - Different result can have different fixed range. + const QuantizedType params = + scale_spec->fixed_output_range_func(is_signed_, bit_width_); + for (auto i = 0; i < op->getNumResults(); ++i) { + // The range is null if the result has been quantized. + if (params) { + changed |= SetResultParams(op, i, params); + } + } + } + + const std::unique_ptr spec = GetQuantSpec(op); + for (const auto& [bias_operand_idx, non_bias_params] : + spec->biases_params) { + const auto& [non_bias_operand_indices, accumulator_scale_func] = + non_bias_params; + const QuantizedType params = + GetBiasParams(op, bias_operand_idx, non_bias_operand_indices, + accumulator_scale_func); + if (!params) { + quantized_.erase(op); + continue; + } + changed |= SetBiasParamsWithAdjustments(op, bias_operand_idx, + non_bias_operand_indices, params); + } + } + + return changed; +} + +// Finalizes the arguments and result states in the function. +void QuantizationDriver::Finalize() { + for (BlockArgument arg : args_) { + const QuantState& state = GetArgQuantState(arg); + RequantizeStates& requantizes = GetArgRequantizeStates(arg); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeArg(arg, state.params); + } + + if (!requantizes.empty()) { + RequantizeArg(arg, requantizes); + } + } + + for (const auto& [op_with_result_idx, quant_state_idx] : result_states_) { + const auto [op, result_idx] = op_with_result_idx; + const QuantState& state = GetResultQuantState(op, result_idx); + RequantizeStates& requantizes = GetResultRequantizeStates(op, result_idx); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeOpResult(op, result_idx, state.params); + } + + if (!requantizes.empty()) { + RequantizeOpResult(op, result_idx, requantizes); + } + } +} + +// Runs quantization in following steps: +// 1. Scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// 2. Propagates the quantization parameters to the operands, results, and +// biases. +// 3. Finalizes the arguments and result states in the function. 
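+//
+// Minimal usage sketch (assuming a suitable `OpQuantSpecGetter`, here called
+// `GetOpQuantSpec`):
+//   QuantizationDriver driver(func, /*is_signed=*/true, /*bit_width=*/8,
+//                             /*disable_per_channel=*/false, GetOpQuantSpec,
+//                             GetDefaultQuantScaleSpec,
+//                             /*infer_tensor_range=*/true);
+//   driver.Run();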
+void QuantizationDriver::Run() { + Initialize(); + if (PropagateParamsAndReturnIfChanged()) { + Finalize(); + } +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + ApplyQuantizationParamsPropagation( + func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + is_qdq_conversion); +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + QuantizationDriver(func, is_signed, bit_width, disable_per_channel, + op_quant_spec_getter, op_quant_scale_spec_getter, + infer_tensor_ranges, legacy_float_scale, is_qdq_conversion) + .Run(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h new file mode 100644 index 00000000000000..18d156ec8aa31d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h @@ -0,0 +1,387 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" + +namespace mlir { +namespace TFL { + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. 
+  QuantizedType params;
+  // A flag indicating that this state (the params) shouldn't be changed after
+  // it is initialized. This flag is set to true if the quantization parameters
+  // come from quantization-aware training.
+  const bool immutable;
+
+  bool IsEmpty() const { return params == nullptr; }
+};
+
+// The state for rescaling the propagated quantization parameters. This can be
+// on the input side to satisfy the constraint of the previous operation, or on
+// the output side to satisfy the constraint of the next operation.
+struct RequantizeState {
+  // Sometimes, we have to "requantize" the quantization result to satisfy all
+  // the constraints. The "requantize" can happen either on the input or output
+  // of the quantization result.
+  enum RequantizePosition {
+    NO_REQUANTIZE,
+    ON_INPUT,
+    ON_OUTPUT
+  } pos = NO_REQUANTIZE;
+
+  // Quantization parameters that will be used to add the requantize ops.
+  QuantizedType params;
+
+  // Avoid clobbering all uses of the value; limit to just these ops.
+  SmallVector<std::pair<Operation*, int>> users;
+};
+
+using RequantizeStates = SmallVector<RequantizeState>;
+
+// This is a worklist-driven driver for propagating quantization parameters
+// across operations.
+//
+// The initial quantization parameters are extracted from the quantized type
+// between adjacent `quantfork::QuantizeCastOp` and
+// `quantfork::DequantizeCastOp`s. All these initial parameters are marked as
+// immutable because they are from quantization-aware training.
+//
+// The algorithm traverses each op and sets the quantization parameters of its
+// operands and results according to its quantization specification, and then
+// adds the operands and results to the worklist. If there is a conflict (for
+// example, quantization parameters were already propagated in a previous
+// iteration), the process either stops, when the existing parameters are
+// immutable, or inserts a `requantize` op to resolve the conflict.
+//
+// After the algorithm converges, pairs of `quantfork::QuantizeCastOp` and
+// `quantfork::DequantizeCastOp` are inserted at the right positions to
+// materialize the propagation and requantize results.
+//
+class QuantizationDriver {
+ public:
+  // Type alias of int used to access `states_`.
+  using QuantStateIndex = int;
+
+  // (op, operand index) pair.
+  using OpWithOperandIndex = std::pair<Operation*, int>;
+
+  // (op, result index) pair.
+  using OpWithResultIndex = std::pair<Operation*, int>;
+
+  explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed,
+                              const int bit_width,
+                              const bool disable_per_channel,
+                              OpQuantSpecGetter op_quant_spec_getter,
+                              OpQuantScaleSpecGetter op_quant_scale_spec_getter,
+                              const bool infer_tensor_range,
+                              const bool legacy_float_scale = false,
+                              const bool is_qdq_conversion = false)
+      : fn_(func_op),
+        builder_(func_op.getBody()),
+        is_signed_(is_signed),
+        bit_width_(bit_width),
+        disable_per_channel_(disable_per_channel),
+        op_quant_spec_getter_(op_quant_spec_getter),
+        op_quant_scale_spec_getter_(op_quant_scale_spec_getter),
+        infer_tensor_range_(infer_tensor_range),
+        legacy_float_scale_(legacy_float_scale),
+        is_qdq_conversion_(is_qdq_conversion) {}
+
+  // The entry point of the quantization parameters propagation.
+  void Run();
+
+  // Sets up the states for all the op results in the function.
+  void Initialize();
+
+  // Propagates the quantization parameters across all the ops.
+  bool PropagateParamsAndReturnIfChanged();
+
+  // Inserts the Quantize and Dequantize ops according to the propagation
+  // result.
+  void Finalize();
+
+  SmallVector<BlockArgument, 4> GetArgs() { return args_; }
+
+  llvm::DenseMap<std::pair<Operation*, int>, int> GetResultStates() {
+    return result_states_;
+  }
+
+  DenseMap<OpWithResultIndex, QuantStateIndex> result_states_;
+
+  // Returns the state of the block argument.
+  QuantState& GetArgQuantState(BlockArgument arg) {
+    return states_[arg_states_[arg]];
+  }
+
+  // Returns the state of the index-th result of the op.
+  QuantState& GetResultQuantState(Operation* op, const int index) {
+    return states_[result_states_[{op, index}]];
+  }
+
+ private:
+  // Duplicates the constant op if it has multiple uses, and replaces
+  // target_op->operand[operand_index] with the newly created op. This also
+  // replaces the corresponding quantization states.
+  arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op,
+                                                Operation* target_op,
+                                                int operand_index);
+
+  // Adjusts the bias scale that is derived from other scales (fc, conv ops) to
+  // prevent overflow of quantized bias values. This also changes the
+  // quantization state of other inputs when needed.
+  bool SetBiasParamsWithAdjustments(Operation* op, int bias_index,
+                                    ArrayRef<int> input_indices,
+                                    QuantizedType params);
+
+  // Checks preconditions for adjusting the bias scale.
+  bool ShouldCheckBiasScale(Operation* op, int bias_index,
+                            ArrayRef<int> input_indices,
+                            QuantizedType quantized_type, int& input_index,
+                            int& filter_index);
+
+  // Preprocesses the constants by doing the following:
+  //   - Duplicates a constant if it is used by multiple ops. For example, if a
+  //     constant is used by multiple ops as a bias, duplicate the constant and
+  //     let each op assign its own quantization parameter for the bias.
+  //   - Adds all the non-bias constants (weights) to a set for later lookup.
+  //   - Adds all per-channel weights to a set for later lookup.
+  void PreprocessConstantOps();
+
+  // Sets up all the data structures for quantization propagation.
+  void SetupAllStates();
+
+  // Returns whether the constant is a weight, which shouldn't be shared by
+  // different ops.
+  bool IsWeight(Operation* cst) { return llvm::is_contained(weights_, cst); }
+
+  // Returns all the related quantization constraints of the op.
+  std::unique_ptr<OpQuantSpec> GetQuantSpec(Operation* op);
+  std::unique_ptr<OpQuantScaleSpec> GetQuantScaleSpec(Operation* op);
+
+  // Returns whether quantization parameters have been propagated to the
+  // results of this op.
+  bool IsQuantized(Operation* op);
+
+  // Adds all the users of the index-th result of op to the work list.
+  void AddUserToList(Operation* op, const int index) {
+    for (Operation* user : op->getResult(index).getUsers()) {
+      work_list_.push_back(user);
+    }
+  }
+
+  // Adds the defining op of the index-th operand of op to the work list.
+  void AddOperandToList(Operation* op, const int index) {
+    if (Operation* operand_op = op->getOperand(index).getDefiningOp();
+        operand_op != nullptr) {
+      work_list_.push_back(operand_op);
+    }
+  }
+
+  // Returns the quantization params for the bias input from the non-bias
+  // operands whose indexes are in `non_bias_operand_indices`. The returned
+  // parameters are calculated by `func`.
+  QuantizedType GetBiasParams(Operation* op, int bias_index,
+                              ArrayRef<int> non_bias_operand_indices,
+                              AccumulatorScaleFunc func);
+
+  // Sets the quantization parameters of the result to `quantized_type`. If
+  // any quantization parameters have been propagated, a requantize will
+  // happen on the input of the propagated quantization. Returns `true` if the
+  // internal state has been modified.
+ bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any + // quantization parameters have been propagated, a `requantize` will happen on + // the output of propagated quantization. When `override` is set, quantization + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); + + // Sets the quantization parameters of the constant result according to its + // content. + bool SetConstantResultParams(Operation* op); + + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); + + // Inserts the Quantize ops for requantizing the index-th result of the op. + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); + + // Inserts the Quantize ops for requantizing a block argument. + void RequantizeArg(BlockArgument arg, RequantizeStates& states); + + // Inserts the Quantize and Dequantize ops to quantize the value and returns + // the Quantize op. + void RequantizeValue(Value value, RequantizeStates& states, Location loc); + + // Returns the quantization parameter satisfies the same scale + // constraints for the op. Returns an empty option if this quantization + // parameter doesn't exist. + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); + + // Returns the state of the index-th operand of the op. + QuantState& GetOperandQuantState(Operation* op, const int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th operand of the op. + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th result of the op. + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + // Returns the states of the arg. + RequantizeStates& GetArgRequantizeStates(BlockArgument arg) { + return rescale_states_[arg_states_[arg]]; + } + + // Sets the state of an argument. If this value is cached, uses the cached + // result without creating new entry in the state vector. Otherwise, allocate + // a new entry in the state vector. + void InitializeArgState(BlockArgument arg, Value arg_value); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeOperandState(Operation* op, int index, Value value); + + // Sets the state of the index-th result of the op. If this result is cached, + // uses the cached result without creating new entry in the state vector. 
+ // Otherwise, allocate a new entry in the state vector. + void InitializeResultState(Operation* op, int index, Value value); + + func::FuncOp fn_; + OpBuilder builder_; + const bool is_signed_; + const int bit_width_; + const bool disable_per_channel_; + + // We should distinguish weights and bias constants. Biases are specified by + // the quantization spec or are the operands of ops with same scale spec. The + // rest are weights. + DenseSet weights_; + + // The weights require narrow_range quantization. This map collects all the + // weight operands defined by the op quant spec. The value of each entry is + // the quantization dimension. If it is positive, per-channel quantization is + // required. + DenseMap optimized_weights_; + + // All the ops needs to propagate the quantization parameters to. + std::vector work_list_; + absl::flat_hash_set quantized_; + + // The vector contains all the quantization parameters propagated from the + // defining operations of the value, or from the quantization aware training. + std::vector states_; + + // The map contains all the quantization parameters which are required to + // satisfy the same operands and results constraint. The keys of this map are + // the values from `operand_states_` and `result_state_`. + absl::flat_hash_map rescale_states_; + + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. + DenseMap operand_states_; + DenseMap arg_states_; + DenseMap value_to_state_; + + // This vector is to preserve the arguments order, so the newly inserted + // quantized ops for the arguments are deterministically ordered. + SmallVector args_; + + OpQuantSpecGetter op_quant_spec_getter_; + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + + // Infer output ranges for activation ops and constants. This is usually + // required for post-training quantization. + const bool infer_tensor_range_; + + // Calculate scales in float instead of double, so that the scales and + // quantized values are exactly the same with the TOCO quantizer. + const bool legacy_float_scale_; + + // If true, the model is a floating point graph with QDQ ops to be eliminated + // and fused into quantized kernels. + const bool is_qdq_conversion_; +}; + +// Propagates quantization parameters across ops in this function and satisfies +// the quantization specification of the ops. This methods assumes the initial +// quantization parameters are stored as adjacent quantize and dequantize ops +// and the propagation results are materialized by inserting pairs of quantize +// and dequantize ops to this function. Set `disable_per_channel` to true to not +// use per channel quantization even the op supports it. +// Setting `infer_tensor_range` to true, to infer quantization parameters from +// the activation ops and weight constants. This is only used for post-training +// quantization. 
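+//
+// Example call (sketch, mirroring the defaults used in the tests):
+//   ApplyQuantizationParamsPropagation(
+//       func, /*is_signed=*/true, /*bit_width=*/8,
+//       /*disable_per_channel=*/false, op_quant_spec_getter,
+//       /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false,
+//       /*is_qdq_conversion=*/false);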
+void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, + int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool infer_tensor_ranges, + bool legacy_float_scale, + bool is_qdq_conversion); + +void ApplyQuantizationParamsPropagation( + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, + bool legacy_float_scale, bool is_qdq_conversion); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc new file mode 100644 index 00000000000000..59ca182bd418be --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::TFL { +namespace { + +using ApplyQuantizationParamsPropagationTest = + mlir::quant::QuantizationTestBase; +using ::testing::IsEmpty; +using ::testing::Not; + +constexpr absl::string_view kModuleTFLite = R"mlir( + module { + func.func @main(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {_from_xla_call_module} { + %cst_0 = arith.constant dense<1.0> : tensor<3x1x1x3xf32> + %cst_1 = arith.constant dense<2.0> : tensor<3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst_0, 
%cst_1) <{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst_1) <{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_2, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %1 : tensor<1x4x4x3xf32> + } + func.func private @composite_fn_1(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} { + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %0 : tensor<1x4x4x3xf32> + } + func.func private @composite_fn_2(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} { + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %0 : tensor<1x4x4x3xf32> + } + } +)mlir"; + +// TOOD: b/323478683 - Directly use types rather than creating a `unique_ptr`. 
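+// The spec built below marks operand 1 as a per-channel-quantizable weight
+// (quantized along dimension 3) and operand 2 as a bias whose parameters are
+// derived from operands 0 and 1 via `GetUniformQuantizedTypeForBias`.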
+std::unique_ptr GetOpQuantSpec( + const mlir::Operation* op, + bool disable_per_channel_for_dense_layers = false) { + auto spec = std::make_unique(); + spec->coeff_op_quant_dim[1] = 3; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + for (const auto& [key, value] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(key); + } + return spec; +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + ConstsUsedMultipleTimesAreDuplicated) { + const OwningOpRef module_op_ref = + mlir::quant::QuantizationTestBase::ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + + int64_t num_constant_op = 0; + main_fn.walk([&](arith::ConstantOp cst) { ++num_constant_op; }); + EXPECT_EQ(num_constant_op, 4); +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + PropagateParamsCreatesQuantState) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + ASSERT_TRUE(quantization_driver.PropagateParamsAndReturnIfChanged()); + EXPECT_THAT(quantization_driver.GetArgs(), Not(IsEmpty())); + + for (const auto& arg : quantization_driver.GetArgs()) { + const QuantState& state = quantization_driver.GetArgQuantState(arg); + EXPECT_TRUE(isa(state.params)); + } + for (const auto& result : quantization_driver.GetResultStates()) { + mlir::Operation* op = result.first.first; + const int res_index = result.first.second; + const QuantState state = + quantization_driver.GetResultQuantState(op, res_index); + EXPECT_TRUE(isa(state.params)); + } +} + +TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + ApplyQuantizationParamsPropagation( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + mlir::Operation* xla_call_module_op = + quant::FindOperationOfType(main_fn); + mlir::Operation* filter_dcast_op = + xla_call_module_op->getOperand(1).getDefiningOp(); + mlir::Operation* filter_qcast_op = + filter_dcast_op->getOperand(0).getDefiningOp(); + ASSERT_NE(filter_qcast_op, nullptr); + EXPECT_TRUE(isa(filter_qcast_op)); + EXPECT_TRUE(isa(filter_dcast_op)); + EXPECT_TRUE(isa( + mlir::cast(filter_qcast_op->getResult(0).getType()) + 
.getElementType())); +} + +} // namespace +} // namespace mlir::TFL diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h new file mode 100644 index 00000000000000..efdf5bf6223038 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow Lite dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +using QuantizedType = mlir::quant::QuantizedType; +using UniformQuantizedType = mlir::quant::UniformQuantizedType; + +namespace mlir { +namespace TFL { +// Verifies that the op satisfies the same operands and results scales +// constraints. Note that this constraint can only be applied on some +// storage types of the op. +LogicalResult VerifySameScales(Operation* op); +} // namespace TFL + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_interface.h.inc" + +namespace OpTrait { +namespace TFL { + +// The base class that all the quantization related OpTrait implements. +template class TraitType> +struct QuantizationSpecTraitBase : public TraitBase { + static bool IsBias(int index) { return false; } + static bool IsQuantizable() { return true; } +}; + +// This class provides the API for ops that has a fixed output value range. +// This is used as a trait like this: +// +// class SoftmaxOp +// : public Op::Impl> { +// +// TODO(fengliuai): create a better way to express floating point scale in the +// template argument list. 
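+// For instance (illustrative values, not tied to any particular op): a fixed
+// output scale of 1/256 = 0.00390625 would be encoded as ScaleMantissa = 390625
+// and ScaleExp = -8, since the trait below reconstructs the scale as
+// ScaleMantissa * 10^ScaleExp.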
+template <unsigned BitWidth, int ZeroPoint, int ScaleMantissa, int ScaleExp,
+          int64_t StorageTypeMin, int64_t StorageTypeMax, bool Sign>
+class FixedResultUniformScale {
+ public:
+  template <typename ConcreteType>
+  class Impl
+      : public QuantizationSpecTraitBase<
+            ConcreteType, FixedResultUniformScale<
+                              BitWidth, ZeroPoint, ScaleMantissa, ScaleExp,
+                              StorageTypeMin, StorageTypeMax, Sign>::Impl> {
+   public:
+    QuantizedType GetResultQuantizedType(int index) {
+      auto op = this->getOperation();
+      const auto result_type =
+          op->getResult(index).getType().template cast<ShapedType>();
+      if (!result_type.getElementType().template isa<FloatType>()) return {};
+      Builder builder(op->getContext());
+      const IntegerType storage_type = builder.getIntegerType(BitWidth);
+      const double scale = static_cast<double>(ScaleMantissa) *
+                           std::pow(10.0, static_cast<double>(ScaleExp));
+      return UniformQuantizedType::getChecked(
+          Sign, storage_type, result_type.getElementType(), scale, ZeroPoint,
+          StorageTypeMin, StorageTypeMax, builder.getUnknownLoc());
+    }
+  };
+};
+
+// This class provides the API for ops that have a bias input. This is used
+// as a trait like this:
+//
+//   class Conv2DOp
+//       : public Op<Conv2DOp,
+//                   OpTrait::TFL::AccumulatorUniformScale<2, 0, 1>::Impl>
+//
+// TODO(fengliuai): support a configurable accumulator bit width.
+template <int Bias, int... Operands>
+class AccumulatorUniformScale {
+ public:
+  template <typename ConcreteType>
+  class Impl
+      : public QuantizationSpecTraitBase<
+            ConcreteType, AccumulatorUniformScale<Bias, Operands...>::Impl> {
+   public:
+    // Whether the index-th operand is a bias.
+    static bool IsBias(int index) { return index == Bias; }
+
+    // Returns the indexes of all the non-bias operands.
+    static std::vector<int> GetAllNonBiasOperands() {
+      return std::vector<int>({Operands...});
+    }
+  };
+};
+
+// The trait to specify the operand index of the coefficient for an affine op
+// and also the quantization dimension if per-axis quantization is supported.
+// If the quantization dimension is -1, per-axis quantization isn't supported.
+//
+//   class Conv2DOp
+//       : public Op<Conv2DOp, OpTrait::TFL::AffineOpCoefficient<0, 1>::Impl>
+//
+template <int QuantDim, int OperandIndex>
+class AffineOpCoefficient {
+ public:
+  template <typename ConcreteType>
+  class Impl
+      : public TraitBase<ConcreteType,
+                         AffineOpCoefficient<QuantDim, OperandIndex>::Impl> {
+   public:
+    static int GetCoefficientOperandIndex() { return OperandIndex; }
+    static int GetQuantizationDim() { return QuantDim; }
+  };
+};
+
+// This class provides the API for ops that can be quantized.
+// This is used as a trait like this:
+//
+//   class LessOp : public Op<LessOp, OpTrait::TFL::QuantizableResult> {
+//
+template <typename ConcreteType>
+class QuantizableResult
+    : public QuantizationSpecTraitBase<ConcreteType, QuantizableResult> {};
+
+}  // namespace TFL
+}  // namespace OpTrait
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc
new file mode 100644
index 00000000000000..3754ae7fb47823
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc
@@ -0,0 +1,1075 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" + +namespace mlir { + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_interface.cc.inc" + +namespace TFL { +namespace { + +constexpr double kSmallestHalfRange = kNearZeroTolerance / 2; +using QType = quant::QuantizedType; + +// Repeats the content of `data` multiple times to resize to `target_size`. +// Note that this only broadcast across one dimension. +template +bool BroadcastVector(int target_size, SmallVectorImpl& data) { + const int size = data.size(); + if (size != target_size) { + if (target_size % size != 0) return true; + data.reserve(target_size); + for (int i = 1; i < target_size / size; ++i) { + data.insert(data.end(), data.begin(), data.begin() + size); + } + } + return false; +} + +// Expands the range to be larger than or equal to 1.0e-6, if it is +// very small (< 1.0e-6). This is to prevent very large quantized value by this +// range. +void ExpandVerySmallRange(const ArrayRef mins, + const ArrayRef maxs, + SmallVectorImpl& effective_mins, + SmallVectorImpl& effective_maxs) { + for (const auto [min, max] : llvm::zip(mins, maxs)) { + // The range is small. Expands the range to stride 0.0 and also at least + // 1.0e-6. 
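+    // Worked example (values illustrative): a constant whose min == max == 0.25
+    // has a zero-width range, so the else-branch below produces
+    // [std::min(0.25, -5.0e-7), std::max(0.25, 5.0e-7)] = [-5.0e-7, 0.25],
+    // which straddles 0.0 and is at least kNearZeroTolerance wide.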
+ if (max - min > kNearZeroTolerance) { + effective_mins.push_back(min); + effective_maxs.push_back(max); + } else { + effective_mins.push_back(std::min(min, -kSmallestHalfRange)); + effective_maxs.push_back(std::max(max, kSmallestHalfRange)); + } + } +} + +// Sets the min / max, scale and zero_points from the fake quant num_bits +// attribute from QAT. +QuantizedType ResetMinMaxFromNumBits(const QuantizedType type, + const int num_bits, + const bool narrow_range, + const bool is_signed) { + if (num_bits >= 8) { + return type; + } + int64_t qmin = QType::getDefaultMinimumForInteger(is_signed, num_bits); + int64_t qmax = QType::getDefaultMaximumForInteger(is_signed, num_bits); + if (narrow_range) { + qmin += 1; + } + const int64_t storage_type_min = type.getStorageTypeMin(); + const int64_t storage_type_max = type.getStorageTypeMax(); + const double rate = + static_cast(storage_type_max - storage_type_min) / (qmax - qmin); + const auto& recalculate_scale = [&](double scale) -> double { + return scale * rate; + }; + const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t { + return qmax - std::round((storage_type_max - zero_point) / rate); + }; + if (auto q_type = dyn_cast(type)) { + const double scale = recalculate_scale(q_type.getScale()); + const double zero_point = recalculate_zero_point(q_type.getZeroPoint()); + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scale, + zero_point, qmin, qmax); + } else if (auto q_type = dyn_cast(type)) { + const int size = q_type.getScales().size(); + SmallVector scales(size); + SmallVector zero_points(size); + for (int i = 0; i < size; ++i) { + scales[i] = recalculate_scale(q_type.getScales()[i]); + zero_points[i] = recalculate_zero_point(q_type.getZeroPoints()[i]); + } + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), qmin, qmax); + } else { + llvm_unreachable("Unsupported QuantizedType in ResetMinMaxFromNumBits"); + } + return type; +} + +// Changes the axis of the input per-channel quantized type to match the +// dimension of the target type. Returns nullptr if it fails. +quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( + const ArrayRef shape, + const quant::UniformQuantizedPerAxisType qtype, const Type target, + const int quant_dim) { + const auto shaped = dyn_cast(target); + if (!shaped) return {}; + const ArrayRef new_shape = shaped.getShape(); + + SmallVector scales(qtype.getScales().begin(), + qtype.getScales().end()); + SmallVector zero_points(qtype.getZeroPoints().begin(), + qtype.getZeroPoints().end()); + + if (new_shape.size() == shape.size()) { // same rank + // Broadcast the scales and zero points to match the target size, which is + // usually the axis-th dimension of the target type. Currently, it covers + // two cases: + // - for Transpose, the data layout is changed so the `dim[axis]` still + // equals to the `scales_size`. The broadcast skips; + // - for Reshape, the data layout isn't changed but the innermost dimension + // is expand to cover the last two original dimensions. Thus we just need to + // be repeated the `scales` dim[2] times to covers the new dim length. 
+ if (BroadcastVector(shaped.getDimSize(quant_dim), scales) || + BroadcastVector(shaped.getDimSize(quant_dim), zero_points)) { + return {}; + } + } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) { + // Handle the [A, B, C] -> [1, A, B, C] reshape case. + if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) && + quant_dim == new_shape.size() - 1)) { + return {}; + } + } else { + return {}; + } + + return quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + scales, zero_points, quant_dim, qtype.getStorageTypeMin(), + qtype.getStorageTypeMax()); +} + +} // namespace + +bool IsOpQuantizable(Operation* op) { + if (isa(op)) { + // Constant ops do not have QuantizableResult attribute but they can deal + // with quantized tensors. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + + const bool attr_output_quantized = QuantizableOpSupportsFloatOutputType(op); + + const bool trait_enforced_quantizable = + op->hasTrait(); + + return attr_enforced_quantizable || trait_enforced_quantizable || + attr_output_quantized; +} + +// Checks if an op has specific attributes that enable quantized inputs with +// float outputs. +bool QuantizableOpSupportsFloatOutputType(Operation* op) { + static constexpr char kOutputTypes[] = "_output_types"; + static constexpr char kSupportOutputTypeFloat[] = + "_support_output_type_float_in_quantized_op"; + + if (!(op->hasAttrOfType(kOutputQuantized) && + op->getAttrOfType(kOutputQuantized).getValue())) { + return false; + } + + if (!(op->hasAttrOfType(kSupportOutputTypeFloat) && + op->getAttrOfType(kSupportOutputTypeFloat) + .getValue())) { + return false; + } + + if (!op->hasAttrOfType(kOutputTypes)) { + return false; + } + + auto output_types_attr = op->getAttrOfType(kOutputTypes); + + if (output_types_attr.size() != op->getResultTypes().size()) { + return false; + } + + for (const auto [attr_element, result_type] : + llvm::zip_equal(output_types_attr, op->getResultTypes())) { + auto type_attr = mlir::dyn_cast_or_null(attr_element); + + if (!type_attr) { + return false; + } + + auto tensor_type = mlir::dyn_cast_or_null(result_type); + + if (!tensor_type) { + return false; + } + + if (type_attr.getValue() != tensor_type.getElementType()) { + return false; + } + } + + return true; +} + +// Returns the quantized type for the +// input_type/min/max/storag_type_width/narrow_range. +// This is entry point to the Quant dialect and used for both quantizing +// activations and weights. +Type GetQuantizedType(Builder builder, const Type input_type, + const ArrayRef min, const ArrayRef max, + const int quant_dim, const int storage_type_width, + const bool narrow_range, const bool is_signed, + const bool legacy_float_scale, + const bool use_fake_quant_num_bits) { + auto converter = + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType(input_type); + + // Expand the range to prevent extremely small scales and large quantized + // integers which can cause overflow. This leads to scale + // 7.843137254901961e-9 with 8 bits. 
+ SmallVector effective_mins, effective_maxs; + ExpandVerySmallRange(min, max, effective_mins, effective_maxs); + + quant::QuantizedType quantized_element_type; + if (min.size() == 1 && max.size() == 1 && quant_dim == -1) { + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, effective_mins[0], + effective_maxs[0], narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins[0], + effective_maxs[0], builder.getUnknownLoc()); + } + } else if (min.size() == max.size()) { + auto shape = dyn_cast(input_type); + if (!shape || shape.getRank() <= quant_dim || + static_cast(min.size()) != shape.getDimSize(quant_dim)) { + return {}; + } + // The quantization dim is set to the last dimension. + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins, + effective_maxs, narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins, effective_maxs, + builder.getUnknownLoc()); + } + } + if (!quantized_element_type) return {}; + // Use fake quant configured bit-widths (only supported for + // 1 < num_bits < 8 bits) instead of using 8-bit defaults. + if (use_fake_quant_num_bits && storage_type_width > 1 && + storage_type_width < 8 && + quantized_element_type.getStorageTypeMax() > + QType::getDefaultMinimumForInteger(is_signed, storage_type_width)) { + const auto resetEleType = ResetMinMaxFromNumBits( + quantized_element_type, storage_type_width, narrow_range, is_signed); + return converter.convert(resetEleType); + } + return converter.convert(quantized_element_type); +} + +// TODO(fengliuai): promote this utility method to mlir QuantOps. +TypeAttr RescaleQuantizedType(const Type input, const Attribute factor) { + const auto factor_values = dyn_cast_or_null(factor); + if (!factor_values) return {}; + const auto element_type = + quant::QuantizedType::getQuantizedElementType(input); + if (!element_type) return {}; + if (auto qtype = dyn_cast(element_type)) { + const ArrayRef scales = qtype.getScales(); + // Broadcasting hasn't been implemented yet. + if (static_cast(scales.size()) != factor_values.getNumElements()) + return {}; + SmallVector new_scales; + new_scales.reserve(scales.size()); + auto scales_iter = scales.begin(); + for (const auto& f : factor_values) { + new_scales.push_back(*scales_iter * + std::fabs(FloatAttr::getValueAsDouble(f))); + ++scales_iter; + } + // We are assuming symmetric quantization. + auto new_ele_type = quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(), + qtype.getStorageTypeMin(), qtype.getStorageTypeMax()); + if (const auto new_type = new_ele_type.castFromExpressedType( + quant::QuantizedType::castToExpressedType(input))) { + return TypeAttr::get(new_type); + } + } + // Currently, we only support per-axis quantized type. 
+  return {};
+}
+
+TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type,
+                              const Attribute min, const Attribute max,
+                              const int quant_dim, const IntegerAttr num_bits,
+                              const BoolAttr narrow_range, const bool is_signed,
+                              const bool legacy_float_scale,
+                              const bool use_fake_quant_num_bits) {
+  SmallVector<double, 4> min_value, max_value;
+  const auto mins = dyn_cast<DenseFPElementsAttr>(min);
+  const auto maxs = dyn_cast<DenseFPElementsAttr>(max);
+  if (mins && maxs) {
+    min_value.reserve(mins.getNumElements());
+    max_value.reserve(maxs.getNumElements());
+    for (auto it = mins.begin(); it != mins.end(); ++it) {
+      min_value.push_back(FloatAttr::getValueAsDouble(*it));
+    }
+    for (auto it = maxs.begin(); it != maxs.end(); ++it) {
+      max_value.push_back(FloatAttr::getValueAsDouble(*it));
+    }
+  } else {
+    const auto fmin = dyn_cast<FloatAttr>(min);
+    const auto fmax = dyn_cast<FloatAttr>(max);
+    if (fmin && fmax) {
+      min_value.push_back(fmin.getValueAsDouble());
+      max_value.push_back(fmax.getValueAsDouble());
+    } else {
+      return {};
+    }
+  }
+  const Type final_type =
+      GetQuantizedType(builder, input_type, min_value, max_value, quant_dim,
+                       num_bits.getInt(), narrow_range.getValue(), is_signed,
+                       legacy_float_scale, use_fake_quant_num_bits);
+  if (!final_type) return {};
+  return TypeAttr::get(final_type);
+}
+
+TypeAttr CastQuantizedTypeAttrFromExpressedType(const Builder builder,
+                                                const TypeAttr source,
+                                                const Type target,
+                                                const int axis) {
+  const auto source_type = dyn_cast_or_null<ShapedType>(source.getValue());
+  if (!source_type) return {};
+  const auto src_ele_type = source_type.getElementType();
+  auto qtype = dyn_cast<quant::QuantizedType>(src_ele_type);
+
+  // Reset the quantization dimensions if it is per-axis.
+  if (const auto per_axis =
+          dyn_cast_or_null<quant::UniformQuantizedPerAxisType>(qtype)) {
+    // For the pass-through ops, we don't know which dimension will be the
+    // new quantization dimension. Only if the new quantization dimension can
+    // be inferred, it is safe to reset the per-axis quantized type.
+    if (axis == -1) return {};
+    qtype =
+        ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
+  }
+  if (!qtype) return {};
+  const Type final_type = qtype.castFromExpressedType(target);
+  if (!final_type) return {};
+  return TypeAttr::get(final_type);
+}
+
+void ExtractMinMaxFromAttr(const DenseFPElementsAttr values, const int dim_size,
+                           const int slice_size, bool symmetric,
+                           SmallVectorImpl<double>& mins,
+                           SmallVectorImpl<double>& maxs) {
+  // If all the element values are the same we don't need to scan the content.
+  if (values.isSplat()) {
+    const double single_value =
+        FloatAttr::getValueAsDouble(values.getSplatValue<APFloat>());
+
+    // When the single value isn't 0.0, we expand it to a range to include
+    // this single value and 0.0. This will give us a scale and zero point
+    // that work for both this value and 0.0.
+    if (single_value < 0.0) {
+      mins[0] = single_value;
+      maxs[0] = symmetric ? -single_value : 0.0;
+    } else if (single_value > 0.0) {
+      mins[0] = symmetric ?
-single_value : 0.0; + maxs[0] = single_value; + } else { + mins[0] = maxs[0] = single_value; + } + for (int i = 1; i < dim_size; ++i) { + mins[i] = mins[0]; + maxs[i] = maxs[0]; + } + } else { + int64_t flatten_index = 0; + auto begin = values.begin(); + auto end = values.end(); + for (auto it = begin; it != end; ++it, ++flatten_index) { + const double ele_value = FloatAttr::getValueAsDouble(*it); + const int slice_index = flatten_index / slice_size; + const int channel_index = slice_index % dim_size; + mins[channel_index] = std::min(mins[channel_index], ele_value); + maxs[channel_index] = std::max(maxs[channel_index], ele_value); + } + // Expand range to include 0. + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(maxs[i], 0.0); + mins[i] = std::min(mins[i], 0.0); + } + if (symmetric) { + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i])); + mins[i] = -maxs[i]; + } + } + } +} + +Type GetUniformQuantizedTypeForWeight( + const ElementsAttr attr, const bool symmetric, const unsigned num_bits, + const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + SmallVector mins(1, std::numeric_limits::max()); + SmallVector maxs(1, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. + ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins, + maxs); + + const auto type = + GetQuantizedType(builder, attr.getType(), mins[0], maxs[0], + /*quant_dim=*/-1, num_bits, narrow_range, is_signed, + legacy_float_scale, use_fake_quant_num_bits); + if (const auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +Type GetUniformQuantizedPerAxisTypeForWeight( + const ElementsAttr attr, const int quant_dim, const bool symmetric, + const unsigned num_bits, const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + const auto shape = cast(attr.getType()).getShape(); + if (static_cast(shape.size()) <= quant_dim) return {}; + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + const int dim_size = shape[quant_dim]; + const int slice_size = + std::accumulate(std::next(shape.begin(), quant_dim + 1), shape.end(), 1, + std::multiplies()); + SmallVector mins(dim_size, std::numeric_limits::max()); + SmallVector maxs(dim_size, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. 
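+  // Example (illustrative): for a [3, 1, 1, 3] filter with quant_dim = 0,
+  // dim_size = 3 and slice_size = 1 * 1 * 3 = 3, so flattened element i is
+  // attributed to channel (i / 3) % 3 while scanning for per-channel ranges.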
+ ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs); + + const auto type = GetQuantizedType( + builder, attr.getType(), mins, maxs, quant_dim, num_bits, narrow_range, + is_signed, legacy_float_scale, use_fake_quant_num_bits); + if (auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +quant::QuantizedType GetUniformQuantizedTypeForBias( + const std::vector& op_types, + const int adjusted_quant_dim, const bool legacy_float_scale) { + if (op_types.empty()) return {}; + + size_t axis_size = 1; + int32_t quant_dim = -1; + Type expressed_type; + // Requires all the op types are valid UniformQuantizedTypes or + // UniformQuantizedPerAxisTypes and also have same expressed type. For all + // the UniformQuantizedPerAxisTypes, the quantization dimension index and + // dimension sizes are same. + for (const auto op_type : op_types) { + if (!op_type) return {}; + if (expressed_type && expressed_type != op_type.getExpressedType()) { + return {}; + } + expressed_type = op_type.getExpressedType(); + + if (const auto type = + dyn_cast(op_type)) { + if (axis_size != 1 && axis_size != type.getScales().size()) return {}; + if (quant_dim != -1 && quant_dim != type.getQuantizedDimension()) + return {}; + axis_size = type.getScales().size(); + quant_dim = type.getQuantizedDimension(); + } else if (!isa(op_type)) { + return {}; + } + } + + // The scale from the UniformQuantizedTypes is broadcasted if there are + // UniformQuantizedPerAxisTypes. + SmallVector scales(axis_size, 1.0); + for (const auto op_type : op_types) { + if (const auto type = + dyn_cast(op_type)) { + for (const auto& index_scale : llvm::enumerate(type.getScales())) { + scales[index_scale.index()] *= index_scale.value(); + } + } else if (const auto type = + dyn_cast(op_type)) { + for (int index = 0; index < axis_size; ++index) { + scales[index] *= type.getScale(); + } + } + } + if (legacy_float_scale) { + for (int i = 0; i < scales.size(); ++i) { + scales[i] = static_cast(scales[i]); + } + } + + // Builds the result quantized type, which has signed 32 bits storage type. + Builder builder(expressed_type.getContext()); + const IntegerType storage_type = builder.getIntegerType(32); + const int64_t storage_type_min = + quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32); + const int64_t storage_type_max = + quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32); + if (axis_size == 1) { + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales[0], + /*zeroPoint=*/0, storage_type_min, storage_type_max); + } else { + SmallVector zero_points(axis_size, 0); + // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. + // If the bias rank is larger than 1 because it was already broadcasted + // to match the output shape, use the last index. 
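+    // Illustrative example: an input with scale 0.5 and a per-channel filter
+    // with scales {0.1, 0.2} produce a bias type with scales {0.05, 0.1},
+    // zero points {0, 0}, and signed 32-bit storage.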
+ return quant::UniformQuantizedPerAxisType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales, zero_points, + /*quantizedDimension=*/std::max(adjusted_quant_dim, 0), + storage_type_min, storage_type_max); + } +} + +ElementsAttr QuantizeLegacy(const Attribute real_value, + const Type tensor_type) { + if (!isa(real_value) || + !quant::QuantizedType::getQuantizedElementType(tensor_type)) { + return {}; + } + const auto real_values_attr = cast(real_value); + auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type); + std::vector real_values; + SmallVector quantized_attr; + real_values.reserve(real_values_attr.getNumElements()); + quantized_attr.reserve(real_values_attr.getNumElements()); + std::transform(real_values_attr.begin(), real_values_attr.end(), + std::back_inserter(real_values), [&](APFloat value) -> float { + return value.convertToFloat(); + }); + const ShapedType new_dense_type = dyn_cast_or_null( + q_type.castExpressedToStorageType(real_values_attr.getType())); + const int width = dyn_cast(q_type.getStorageType()).getWidth(); + + if (width == 8 && q_type.getStorageTypeMax() == 127 && + q_type.getStorageTypeMin() == -127) { + std::vector quantized_values(real_values_attr.getNumElements()); + if (auto uniform_type = dyn_cast(q_type)) { + float min, max, scale; + mlir::lite::toco_legacy::PortableSymmetricQuantizeFloats( + real_values.data(), real_values.size(), quantized_values.data(), &min, + &max, &scale); + // The scale has been adjusted, so the adjusted scale should be respected. + if (std::abs(scale - uniform_type.getScale()) > 1e-3) { + return Quantize(real_value, tensor_type); + } + } else if (auto uniform_type = + dyn_cast(q_type)) { + std::vector scales_inv; + std::vector dimension; + dimension.insert(dimension.end(), new_dense_type.getShape().begin(), + new_dense_type.getShape().end()); + std::transform(uniform_type.getScales().begin(), + uniform_type.getScales().end(), + std::back_inserter(scales_inv), + [](float scale) { return 1.0 / scale; }); + + tflite_migration::optimize::utils::SymmetricPerChannelQuantizeValues( + real_values.data(), scales_inv, dimension, + uniform_type.getQuantizedDimension(), &quantized_values); + } else { + return {}; + } + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int8_t value) -> APInt { + return APInt(8, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } else if (width == 8) { + // This can be a state tensor, or an actual constant tensor with + // asymmetric range. For a state tensor, assigning correct quantization + // parameters is sufficient, and for constants with asymmetric range it's + // not correctly quantized by legacy quantizer so call the new Quantize. 
+ return Quantize(real_value, tensor_type); + } else if (width == 16) { + if (const auto uniform_type = dyn_cast(q_type)) { + const auto quantized_values = + tflite_migration::optimize::utils::SymmetricQuantizeFloatsToInt16( + real_values.data(), real_values.size(), uniform_type.getScale()); + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int16_t value) -> APInt { + return APInt(16, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + } else if (width == 32) { + std::vector scales; + if (const auto uniform_type = dyn_cast(q_type)) { + scales.push_back(uniform_type.getScale()); + } else if (const auto uniform_type = + dyn_cast(q_type)) { + scales.insert(scales.end(), uniform_type.getScales().begin(), + uniform_type.getScales().end()); + } else { + return {}; + } + const auto quantized_bias = + tflite_migration::optimize::utils::SymmetricBiasQuantize( + real_values.data(), real_values.size(), scales); + std::transform(quantized_bias.begin(), quantized_bias.end(), + std::back_inserter(quantized_attr), + [&](int32_t value) -> APInt { + return APInt(32, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + return {}; +} + +ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { + if (const auto q_type = + quant::QuantizedType::getQuantizedElementType(tensor_type)) { + Type converted_type; + return dyn_cast_or_null( + quantfork::quantizeAttr(real_value, q_type, converted_type)); + } + return {}; +} + +quant::QuantizedType DownCastScale(QuantizedType type, double min, double max, + Location loc) { + const SmallVector mins = {min}; + const SmallVector maxs = {max}; + return DownCastScale(type, mins, maxs, loc); +} + +quant::QuantizedType DownCastScale(QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc) { + // The given type can be null. For example, there can be an invalid scale and + // so on. + if (!type) return type; + SmallVector scales(mins.size()); + SmallVector zero_points(mins.size()); + if (auto q_type = dyn_cast(type)) { + zero_points.push_back(q_type.getZeroPoint()); + } else if (auto q_type = dyn_cast(type)) { + zero_points = {q_type.getZeroPoints().begin(), + q_type.getZeroPoints().end()}; + } + for (int i = 0; i < mins.size(); ++i) { + scales[i] = (static_cast(maxs[i]) - static_cast(mins[i])) / + (type.getStorageTypeMax() - type.getStorageTypeMin()); + if (type.getStorageTypeMax() != -type.getStorageTypeMin()) { + // Only applies for asymmetric quantized range with original scale. 
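+      // Worked example (values illustrative): min = -1.0, max = 2.0 on a
+      // [0, 255] storage range gives scale = 3.0 / 255 ~= 0.01176 and
+      // zero_point_from_min = 0 - (-1.0 / 0.01176) ~= 85, which already lies
+      // inside the storage range and is simply rounded.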
+ const float zero_point_from_min = + type.getStorageTypeMin() - mins[i] / scales[i]; + if (zero_point_from_min < type.getStorageTypeMin()) { + zero_points[i] = static_cast(type.getStorageTypeMin()); + } else if (zero_point_from_min > type.getStorageTypeMax()) { + zero_points[i] = static_cast(type.getStorageTypeMax()); + } else { + zero_points[i] = static_cast(std::round(zero_point_from_min)); + } + } + } + if (auto q_type = dyn_cast(type)) { + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scales[0], + zero_points[0], q_type.getStorageTypeMin(), + q_type.getStorageTypeMax()); + } else if (auto q_type = dyn_cast(type)) { + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), + q_type.getStorageTypeMin(), q_type.getStorageTypeMax()); + } + return type; +} + +// A heuristic to determine whether the scales needs to be from operands or +// from results for the ops with the `SameOperandsAndResultsScale` property. +// The current implementation is based on the number of operands. +static bool PreferResultScale(Operation* op) { + int float_operands = 0; + for (auto operand : op->getOperands()) { + if (auto operand_type = dyn_cast(operand.getType())) { + if (isa(operand_type.getElementType())) { + if (++float_operands > 1) return true; + } + } + } + return false; +} + +std::unique_ptr GetDefaultQuantScaleSpec(Operation* op) { + auto spec = std::make_unique(); + if (isa(op)) { + spec->has_same_scale_requirement = true; + spec->required_same_scale_func = [op](const bool sign, + const int bit_width) { + return cast(op) + .RequiredSameOperandsAndResultsScale(sign, bit_width); + }; + spec->required_same_quantized_axes_func = [op]() { + return cast(op).RequiredSameQuantizedAxes(); + }; + } + if (isa(op)) { + spec->has_fixed_output_range = true; + spec->fixed_output_range_func = [op](bool sign, int bit_width) { + return cast(op).GetFixedOutputRange(sign, + bit_width); + }; + } + return spec; +} + +// The stats op of some of the ops can be redundant. The current implementation +// only considers the ops with restricted output params. +static bool IsStatsRedundant( + Operation* op, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has FixedOutputRangeInterface, no need to manually create spec. + return isa(op) || + op_quant_scale_spec_getter(op)->has_fixed_output_range; +} + +static bool IsSameScaleOp( + Operation* op, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has SameScalesOpInterface, no need to manually create spec. + return dyn_cast(op) || + op_quant_scale_spec_getter(op)->has_same_scale_requirement; +} + +bool RemoveRedundantStatsOps( + func::FuncOp func, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + SmallVector all_stats_ops; + llvm::DenseSet redundant_stats_ops; + + // Step 0: remove the quantfork::StatisticsOp which are used by the + // quant.qcast op in case it overrides the information from training FakeQuant + // ops. + func.walk([&](quantfork::QuantizeCastOp q) { + auto input_op = q.getArg().getDefiningOp(); + if (auto stats = dyn_cast_or_null(input_op)) { + q.setOperand(stats.getArg()); + if (stats.use_empty()) stats.erase(); + } + }); + + // Step 1: forward pass: propagate any value scales which are not produces + // by `SameOperandsAndResultsScale`. 
Additionally, remove the value scales + // which are produced by the ops with the `FixedOutputRangeInterface`. + // Note that we don't propagate across the multiple-operands + // `SameOperandsAndResultsScale` ops like `concatenation`. + func.walk([&](quantfork::StatisticsOp stats_op) { + all_stats_ops.push_back(stats_op); + }); + + while (!all_stats_ops.empty()) { + quantfork::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (auto def = stats_op.getArg().getDefiningOp()) { + if (IsStatsRedundant(def, op_quant_spec_getter, + op_quant_scale_spec_getter)) { + redundant_stats_ops.insert(stats_op); + } + } + + for (Operation* user : stats_op.getResult().getUsers()) { + // We don't propagate this parameter down if it has multiple operands. + // We want to use the result parameter scales instead. + if (!IsSameScaleOp(user, op_quant_scale_spec_getter) || + PreferResultScale(user)) { + continue; + } + for (Value res : user->getResults()) { + if (!res.hasOneUse()) { + continue; + } + if (auto next_stats = + dyn_cast(*res.getUsers().begin())) { + // quantization parameters can be propagated to next_stats + redundant_stats_ops.insert(next_stats); + // add next_stats to the work list so propagation can continue. + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step 2: backward pass: For the ops skipped in the forward pass, propagate + // its results scale backwards as far as possible. + func.walk([&](quantfork::StatisticsOp stats_op) { + if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) { + all_stats_ops.push_back(stats_op); + } + }); + + while (!all_stats_ops.empty()) { + quantfork::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (Operation* def = stats_op.getArg().getDefiningOp()) { + if (!IsSameScaleOp(def, op_quant_scale_spec_getter)) { + continue; + } + for (Value input : def->getOperands()) { + if (auto next_stats = dyn_cast_or_null( + input.getDefiningOp())) { + redundant_stats_ops.insert(next_stats); + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step3: Remove all the redundant stats ops + for (Operation* it : redundant_stats_ops) { + if (!isa(it)) return true; + auto stats_op = cast(it); + stats_op.getResult().replaceAllUsesWith(stats_op.getArg()); + stats_op.erase(); + } + + // Returns false if the steps finish without errors. + return false; +} + +LogicalResult VerifySameScales(Operation* op) { + auto same_scale_op = cast(op); + + SmallVector collected_quant_params; + for (Value input : op->getOperands()) { + QuantizedType quant_params = + QuantizedType::getQuantizedElementType(input.getType()); + // Skip non-quantizable operands. + if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + for (Value output : op->getResults()) { + const QuantizedType quant_params = + QuantizedType::getQuantizedElementType(output.getType()); + // Skip non-quantizable results. + if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + if (collected_quant_params.size() <= 1) return success(); + const auto& expected_params = collected_quant_params[0]; + for (int i = 1; i < collected_quant_params.size(); ++i) { + const auto& compared_params = collected_quant_params[i]; + // For some ops (such as Transpose or Squeeze), the quantized axis might not + // be the same, this function only verifies the scale and zero point in + // that case. The quantized axis should be verified in their own verifier + // method. 
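+    // For instance, a per-axis quantized Transpose may legitimately move the
+    // quantized dimension (say from 3 to 1) while keeping identical scales
+    // and zero points; that case is accepted by the check below.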
+ if (!same_scale_op.RequiredSameQuantizedAxes()) { + const auto expected_per_axis_qtype = + dyn_cast(expected_params); + const auto compared_per_axis_qtype = + dyn_cast(compared_params); + if (expected_per_axis_qtype && compared_per_axis_qtype && + llvm::equal(expected_per_axis_qtype.getScales(), + compared_per_axis_qtype.getScales()) && + llvm::equal(expected_per_axis_qtype.getZeroPoints(), + compared_per_axis_qtype.getZeroPoints()) && + expected_params.getStorageType() == + compared_params.getStorageType() && + expected_params.getExpressedType() == + compared_params.getExpressedType()) { + continue; + } + } + // Same quantization parameters are always ok. + if (expected_params == compared_params) continue; + // If the quantization parameters are not the same, as long as it has the + // same storage type and the op interface doesn't require same scale + // constraint for this storage type, it is still ok. + if (expected_params.isSigned() == compared_params.isSigned() && + expected_params.getStorageTypeIntegralWidth() == + compared_params.getStorageTypeIntegralWidth() && + !same_scale_op.RequiredSameOperandsAndResultsScale( + expected_params.isSigned(), + expected_params.getStorageTypeIntegralWidth())) + continue; + + std::string err_msg = + "quantization parameters violate the same scale constraint: "; + llvm::raw_string_ostream os(err_msg); + expected_params.print(os); + os << " vs. "; + compared_params.print(os); + os.flush(); + return op->emitOpError(err_msg); + } + return success(); +} + +quant::UniformQuantizedType GetFixedOutputRange( + const bool is_signed, const int bit_width, const Type tensor_type, + const double scale, int64_t zero_point, int64_t storage_min, + int64_t storage_max) { + const auto result_type = cast(tensor_type); + if (!isa(result_type.getElementType())) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits and 16-bits + if (bit_width != 8 && bit_width != 16) return {}; + const IntegerType storage_type = builder.getIntegerType(bit_width); + if (!is_signed && bit_width == 8) { + zero_point += 128; + storage_min += 128; + storage_max += 128; + } + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), is_signed, storage_type, + result_type.getElementType(), scale, zero_point, storage_min, + storage_max); +} + +quant::UniformQuantizedType GetFixedOutputRange(const bool is_signed, + const int bit_width, + const Type tensor_type, + const double scale, + const int64_t zero_point) { + return GetFixedOutputRange(is_signed, bit_width, tensor_type, scale, + zero_point, + /*storage_min=*/-(1 << (bit_width - 1)), + /*storage_max=*/(1 << (bit_width - 1)) - 1); +} + +Type ConvertSignedQuantizedToUnsigned(const Type signed_tensor_type, + const Location loc) { + const auto qtype = QType::getQuantizedElementType(signed_tensor_type); + if (!qtype || !qtype.isSigned()) return {}; + + const int num_bits = qtype.getStorageTypeIntegralWidth(); + // This is a negative value, and will be applied on zero points and fixed + // point ranges. 
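+  // For 8 bits the offset is -128 - 0 = -128, so an i8 zero point of 0 maps
+  // to a ui8 zero point of 128 and the storage range [-128, 127] maps to
+  // [0, 255].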
+ const int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits); + + const auto flags = !quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = dyn_cast(qtype)) { + new_qtype = quant::UniformQuantizedType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = + dyn_cast(qtype)) { + const auto zero_points = aqtype.getZeroPoints(); + SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0; i < new_zero_points.size(); ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } + return new_qtype.castFromExpressedType( + QType::castToExpressedType(signed_tensor_type)); +} + +LogicalResult RemoveDebugAttrPattern::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + // removeAttr will return nullptr if the attribute did not exist. Thus we can + // return success(result) to indicate if this op has changed. + return success(/*isSuccess=*/ + op->removeAttr(kDebugModeOpQuantAttrName) || + op->removeAttr(kDebugModeOpFloatAttrName)); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h new file mode 100644 index 00000000000000..4bac179aff06fb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h @@ -0,0 +1,974 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace TFL { + +// A unit attribute can be attached to the quantize/dequantize ops which are +// added by the quantization passes. These ops can be removed erased without +// losing accuracy. +inline constexpr char kVolatileOpAttrName[] = "volatile"; + +// Following attributes are used to mark ops that are not quantizable during +// debug model generation process for whole-model verify mode. If these +// attributes are attached, the upstream float/quantized ops know which ops to +// connect to, and it also prevents these ops from being copied again. +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; + +// Used to annotate custom ops if they are quantizable. 
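+// For example, an op carrying `_tfl_quant_trait = "fully_quantizable"` is
+// treated as quantizable even if it does not implement the QuantizableResult
+// trait (see IsOpQuantizable in the .cc file).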
+inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; +inline constexpr char kOutputQuantized[] = "_output_quantized"; + +inline constexpr double kNearZeroTolerance = 1.0e-6; + +using QuantParams = QuantizedType; +using QuantSpec = QuantizationSpecs; +using SignedInteger = std::pair; // bitwidth and sign +using QuantParamsForResults = llvm::SmallVector; +using AccumulatorScaleFunc = + std::function&, int, bool)>; +using BiasParamsMap = + absl::flat_hash_map, AccumulatorScaleFunc>>; +// UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) +using GetFixedOutputRangeFunc = std::function; +// bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) +using RequiredSameOperandsAndResultsScaleFunc = std::function; +// bool RequiredSameQuantizedAxes() +using RequiredSameQuantizedAxesFunc = std::function; + +using CustomMap = CustomOpMap; +using Operation = ::mlir::Operation; + +// Quantization spec of an op, driving the quantization algorithm. +struct OpQuantSpec { + // Maps the operand index of a bias input to its quantization specifications, + // including the non-bias operand indexes and the method retrieving + // quantization parameters from list of parameters of the non-bias operands. + // This map is empty if the op doesn't have a bias operand. + BiasParamsMap biases_params; + + // Quantization parameters for value restricted outputs. This is the + // "hard-coded" parameters and should be used unconditionally for the + // quantized op. This vector is empty if the op doesn't have value restricted + // outputs. + llvm::DenseMap restricted_output_params; + + // Coefficient operand index and whether supporting per-channel quantization. + // For QAT, this information is carried by the FakeQuant*/Quantize/Dequantize + // ops, but post-training quantization, the quantization parameters need to be + // inferred from the tensor content and op property. A "-1" value indicates + // the operand doesn't support per-channel quantization. + llvm::DenseMap coeff_op_quant_dim; + + // Indices of quantizable operands. Biases are not included in this field, + // the indices of biases can be found in the `biases_params`. + absl::flat_hash_set quantizable_operands; +}; + +// A function signature for getting the particular OpQuantSpec for the provided +// op. +using OpQuantSpecGetter = + std::function(mlir::Operation*)>; + +// Quantization scale spec of an op. The information defined in the MLIR +// interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should +// be checked first if present. +// TODO: b/323478683: Consider deprecating this. +struct OpQuantScaleSpec { + // Whether this op has a fixed range requirement (e.g. sigmoid) + bool has_fixed_output_range = false; + // Whether this op should have same operand and result scales (e.g. concat) + bool has_same_scale_requirement = false; + // Whether this op should have same operand and result type (e.g. gather) + bool has_same_operand_and_result_type_requirement = false; + // Returns the fixed output range, when has_fixed_output_range is set. + GetFixedOutputRangeFunc fixed_output_range_func; + // Returns whether same operands and results scales are required. 
+ RequiredSameOperandsAndResultsScaleFunc required_same_scale_func = + [](bool sign, int bit_width) { return true; }; + // Returns whether operands and results must have the same quantized axis. + RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() { + return true; + }; +}; + +// A function signature for getting the particular OpQuantScaleSpec for the +// provided op. +using OpQuantScaleSpecGetter = + std::function(mlir::Operation*)>; + +// Used in TFL Numeric Verify +struct NumericVerifySpec { + // Whether to enable numeric verification + bool verify_numeric = false; + + // Tolerance level from the quantized value for verification. If the tolerance + // is very small(<0.1), only the stats of the diff is displayed. + float error_tolerance = 5.0f; + + // Whether to verify numerical correctness layer by layer or by whole model + bool whole_model_verify = false; + + // Whether to enable log for failures + bool log_if_failed_flag = false; +}; + +// Used in TFL Quantize Pass +struct QuantPassSpec { + // Variables to control TFL Numeric Verify + NumericVerifySpec numeric_verify_spec; + + // Variables related to quantization + QuantSpec quant_spec; +}; + +// Re-calculates scales again in float instead of simply downcasting existing +// scales. +quant::QuantizedType DownCastScale(quant::QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc); + +quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, + double max, Location loc); + +bool IsOpQuantizable(mlir::Operation* op); +bool QuantizableOpSupportsFloatOutputType(mlir::Operation* op); + +// Specialized version of location to string for flatbuffer exported locations. +inline std::string GetTensorNameFromLoc(Location loc) { + if (auto name_loc = loc.dyn_cast()) { + return name_loc.getName().str(); + } + return ""; +} + +template +struct ConvertStatsToQDQs : public OpRewritePattern { + ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed, + bool legacy_float_scale, MLIRContext* context) + : OpRewritePattern(context), + num_bits(num_bits), + narrow_range(narrow_range), + is_signed(is_signed), + legacy_float_scale(legacy_float_scale) {} + + LogicalResult matchAndRewrite(quantfork::StatisticsOp op, + PatternRewriter& rewriter) const override { + Type expressed = op.getType().cast().getElementType(); + quant::QuantizedType quant_type; + SmallVector mins, maxs; + + if (op.getAxisStats().has_value()) { + // Per axis quantization (or per channel quantization) + int stats_num = op.getAxisStats()->getNumElements(); + if (stats_num == 0 || stats_num % 2 != 0) return failure(); + auto stats = op.getAxisStats()->dyn_cast(); + if (!stats) return failure(); + + for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { + double rmin = FloatAttr::getValueAsDouble(*it++); + double rmax = FloatAttr::getValueAsDouble(*it); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer + // supports only symmetric quantization. 
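+          // e.g. a calibrated per-channel range of [-1.5, 6.0] is symmetrized
+          // to [-6.0, 6.0] below before building the 16-bit quantized type.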
+ rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + mins.push_back(rmin); + maxs.push_back(rmax); + } + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range, + expressed, is_signed); + if (legacy_float_scale) { + quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); + } + } else if (auto stats = + op.getLayerStats().dyn_cast()) { + // Per tensor quantization + auto statValues = stats.getValues(); + double rmin = FloatAttr::getValueAsDouble(statValues[0]); + double rmax = FloatAttr::getValueAsDouble(statValues[1]); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer supports + // only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + quant_type = + quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, + narrow_range, expressed, is_signed); + if (legacy_float_scale) { + quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc()); + } + } else { + return failure(); + } + + rewriter.setInsertionPointAfter(op.getOperation()); + Type result_type = quant_type.castFromExpressedType(op.getType()); + auto q = + rewriter.create(op.getLoc(), result_type, op.getArg()); + q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + + auto dq = rewriter.create(op.getLoc(), op.getType(), q); + op.getResult().replaceAllUsesWith(dq); + q.getOperation()->replaceUsesOfWith(dq, op.getArg()); + op.erase(); + + return success(); + } + + private: + int num_bits; + bool narrow_range; + bool is_signed; + bool legacy_float_scale; + + // Emits an op warning message if the calibrated range is larger than 10.0 and + // the storage type is less than or equal to 8 bits. + void TensorRangeSanityCheck(quantfork::StatisticsOp op, double& min, + double& max) const { + double range = std::fabs(max - min); + if (num_bits <= 8 && range >= 10.0) { + op.emitWarning() + << "Tensor range is too wide to be quantized. Use tf.clip_by_value " + "or tf.relu6 to narrow the tensor range. Range: " + << range << ", bit width: " << num_bits; + } + if (std::abs(max - min) < kNearZeroTolerance) { + op.emitWarning() << "Tensor range (" << min << ", " << max + << ") is too narrow and it might cause overflow. " + "Expanding range symmetrically by " + << kNearZeroTolerance; + min -= kNearZeroTolerance; + max += kNearZeroTolerance; + } + } +}; + +template +bool UsedBy(mlir::Operation* op) { + for (mlir::Operation* user : op->getUsers()) { + if (llvm::isa_and_nonnull(user)) return true; + } + return false; +} + +template +void CreateVerifier(mlir::Operation* quantizing_op, + mlir::Operation* quantized_op, PatternRewriter& rewriter, + int result_idx, const QuantPassSpec& quant_params) { + rewriter.setInsertionPointAfter(quantized_op); + FloatAttr tolerance = rewriter.getF32FloatAttr( + quant_params.numeric_verify_spec.error_tolerance); + BoolAttr log = + rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag); + // Verify the quantized value by sending the result to the verifier. 
+ rewriter.create( + quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(), + quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx), + tolerance, log); +} + +template <> +inline bool UsedBy(mlir::Operation* op) { + return false; +} + +// This specialization is not going to be called, but needed for compilation. +template <> +inline void CreateVerifier(mlir::Operation* quantizing_op, + mlir::Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) {} + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation *) const +// bool AllowDynamicRangeQuantizedResult(Operation *) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. +template +class QuantizationPattern : public RewritePattern { + public: + using BaseType = QuantizationPattern; + + explicit QuantizationPattern(MLIRContext* context, + const QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + LogicalResult matchAndRewrite(mlir::Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of this QuantizeOp has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. 
+ return failure(); + } + if (mlir::Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + tensorflow::DataType inference_type = + quant_params_.quant_spec.inference_type; + bool weight_only_quantization = + quant_params_.quant_spec.weight_only_quantization; + bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; + bool enable_whole_model_verify = + quant_params_.numeric_verify_spec.whole_model_verify; + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (mlir::Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, not quantizable or any ops from the mlir quant + // ops dialect, we shouldn't rewrite. In case of whole-model verify debug + // mode, not-quantizable ops should be duplicated to keep parallel + // float/quant model execution. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizable(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + quantizing_op, custom_map)) { + if (!(enable_verify && enable_whole_model_verify)) { + return failure(); + } + if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) || + quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) { + return failure(); + } + + rewriter.setInsertionPoint(quantizing_op); + mlir::Operation* float_op = rewriter.clone(*quantizing_op); + quantizing_op->setAttr(kDebugModeOpQuantAttrName, + rewriter.getUnitAttr()); + float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr()); + RewireFloatModelBackbone(quantizing_op, float_op); + return success(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. + if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // An op with float inputs and outputs are expected when it's used by a + // NumericVerify op. Skip this op. + if (enable_verify && UsedBy(quantizing_op)) { + continue; + } + + bool is_operand_or_result_modified = false; + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (operand_type.isa()) { + inputs.push_back(operand); + continue; + } + + auto ele_type = operand.getType().cast().getElementType(); + if (static_cast(this) + ->AllowDynamicRangeQuantizedOperand(quantizing_op, + custom_map)) { + auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); + + if (dq_op && inference_type == tensorflow::DT_QINT8 && + !static_cast(this)->IsWeightOnlyOp( + quantizing_op, ops_blocklist, weight_only_quantization, + custom_map)) { + // Dynamic range quantization is applied by having QuantizeOp as an + // input. 
Only int8 weight is supported for now. + inputs.push_back(dq_op.getOperand()); + is_operand_or_result_modified = true; + } else { + // Otherwise, it's the case where the operand is activations or the + // quantizing_op is non-supported/weight-only. + inputs.push_back(operand); + } + } else { + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + is_operand_or_result_modified = true; + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + } + + mlir::Operation* quantized_op; + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, quantizing_op->getResultTypes(), quantizing_op->getAttrs()); + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region* target_region = new_state.addRegion(); + IRMapping mapping; + indexed_regions.value().cloneInto(target_region, mapping); + } + quantized_op = rewriter.create(new_state); + rewriter.replaceOp(quantizing_op, quantized_op); + } else { + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none + // type results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + result.getType().cast().getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + is_operand_or_result_modified = true; + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + // For float16 quantization if none of the operand or result is + // modified, replacing the op. See b/335025403. 
+ if (inference_type == tensorflow::DT_HALF && + !is_operand_or_result_modified) { + return failure(); + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + + // To verify the numericals, the original floating-point ops are + // preserved in the graph. The result of these floating-point ops are sent + // to a numeric verifier op as the reference. + if (enable_verify && !std::is_same_v) { + // For constant operands, the floating-point constant is duplicated in + // case it is quantized. + for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) { + auto def = quantized_op->getOperand(i).getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null(def)) { + DenseFPElementsAttr attr; + if (!matchPattern(q.getOperand(), m_Constant(&attr))) { + continue; + } + auto cst = rewriter.create( + quantized_op->getLoc(), attr); + quantizing_op->setOperand(i, cst.getResult()); + } + } + + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!quantizing_op->getResult(i) + .getType() + .cast() + .getElementType() + .isa()) { + continue; + } + CreateVerifier(quantizing_op, quantized_op, rewriter, i, + quant_params_); + + if (enable_whole_model_verify) { + RewireFloatModelBackbone(quantized_op, quantizing_op); + } + } + } + } + return success(); + } + + private: + // Reconnects float ops in the whole-model verify mode. Works for both + // Quantizable ops and Unquantizable ops + void RewireFloatModelBackbone(mlir::Operation* quantized_op, + mlir::Operation* float_op) const { + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!float_op->getResult(i) + .getType() + .cast() + .getElementType() + .isF32()) { + continue; + } + // Find the Quantize/Dequantize users of the new op results, and replace + // the usage. Then all the floating-point ops are connected, forming a + // separate float "backbone" model that the quantized model can be + // compared against in parallel. + // N.B. the return op will use this floating-point result. + Value result; + if (!IsOpQuantizable(float_op)) { + // For not quantizable ops, search for dequantize attached to the + // quantized op of the output. + if (mlir::Operation* quantize_op = dyn_cast_or_null( + *quantized_op->getResult(i).getUsers().begin())) { + result = quantize_op->getResult(0); + } else { + quantized_op->emitError() + << "Output[" << i + << "] is expected to have only one user [QUANTIZE]"; + return; + } + } else { + result = quantized_op->getResult(i); + } + for (auto user : result.getUsers()) { + // Skip the Requantize op and set the user to the following dequantize + // op. This happens when the quantizer tries to match the scale conflict + // with QuantizeOp - QuantizeOp(requant) - DequantizeOp triples. The + // correct float op should be the user of the last DequantizeOp. 
+ if (llvm::isa(user)) { + user = *user->getResult(0).getUsers().begin(); + } + if (auto dequantize = llvm::dyn_cast(user)) { + // Replace all uses, except not quantizable ops that are being used in + // the float backbone. + dequantize.getResult().replaceUsesWithIf( + float_op->getResult(i), [&](OpOperand& use) { + return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName); + }); + } + } + } + } + + QuantPassSpec quant_params_; +}; + +// A pattern that removes debug attributes that are annotated to ops during +// the debug model creation. +class RemoveDebugAttrPattern : public RewritePattern { + public: + explicit RemoveDebugAttrPattern(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(mlir::Operation* op, + PatternRewriter& rewriter) const override; +}; + +// Converts quantized tensor type with signed integer type to quantized tensor +// type with unsigned integer type. +Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc); + +// Converts quantize ops with unsigned quantized types to these with signed +// quantized types and preserves the scales. +template +struct ConvertUnsignedToSigned : public OpRewritePattern { + using BaseType = ConvertUnsignedToSigned; + using QType = quant::QuantizedType; + + explicit ConvertUnsignedToSigned(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(QuantizeOpT op, + PatternRewriter& rewriter) const override { + Type output_type = op.getResult().getType(); + auto qtype = QType::getQuantizedElementType(output_type); + if (!qtype || qtype.isSigned()) return failure(); + + int num_bits = qtype.getStorageTypeIntegralWidth(); + if (num_bits == 8) { + // If storage is 8-bit, trained num bits may be less than 8 so check here. + num_bits = + static_cast(std::ceil(std::log2(qtype.getStorageTypeMax()))); + } + // This is a positive value, and will be applied on zero points and fixed + // point ranges. + int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits); + + auto flags = quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = qtype.template dyn_cast()) { + new_qtype = quant::UniformQuantizedType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = qtype.template dyn_cast< + quant::UniformQuantizedPerAxisType>()) { + auto zero_points = aqtype.getZeroPoints(); + llvm::SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0, e = new_zero_points.size(); i < e; ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } else { + return failure(); + } + + if (!new_qtype) return failure(); + Type new_output_type = new_qtype.castFromExpressedType( + QType::castToExpressedType(output_type)); + rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); + return success(); + } +}; + +// Fold Extra Requantize ops if the preceding ops has free scale requirement. 
+template +struct FoldTrivalRequantizeOp : public OpRewritePattern { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(RequantizeOpT op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op->getOperand(0); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + mlir::Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (llvm::isa(def) || + !def->hasTrait()) { + return failure(); + } + + // This op should not clobber def, if more than one requant of this value. + if (!pre_quantized.hasOneUse()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. Please fix the source graph."); + + llvm::SmallVector new_output_types; + for (auto result : def->getResults()) { + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.getResult().getType()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.create(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + +// Given a quantized type `input`, magnifying its scales by the factor stored in +// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the +// dimension size of `input` or isn't floating-point, nullptr will be returned. +TypeAttr RescaleQuantizedType(Type input, Attribute factor); + +// Converts the min/max/num_bits/narrow_range information to a +// QuantizedType, and then returns the attribute containing the QuantizedType. +// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and +// returns UniformQuantizedType or UniformQuantizedPerAxisType respectively. +// `narrow_range` is set to true for weights and `is_signed` is set to true +// if it is using signed int symmetric quantization. +// +// Note that this method may broadcast min and max to match the dimension length +// of `input_type`, if the `quant_dim` is valid. On the other hand, the +// symmetry of min and max is not adjusted by this method. The QAT workflow +// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true) +// if symmetric quantization is required. +TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, + Attribute max, int quant_dim, + IntegerAttr num_bits, BoolAttr narrow_range, + bool is_signed, bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Casts the `target` type to a quantized type by using the quantization +// parameters from the type in the `source` type attribute. +// Examples: +// f32 -> !quant.uniform +// tensor<4xf32> -> tensor<4x!quant.uniform> +// The result is wrapped by a type attribute. Returns nullptr if the cast +// isn't valid. +// +// `axis` is to specify the quantization dimension in the `target` and only +// used if the element type of `source` is a per-channel quantized type. During +// the casting, the quantization dimension of the result type needs to be set +// this new `axis` value. 
+TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
+                                                TypeAttr source, Type target,
+                                                int axis);
+
+// Quantizes the elements in the attribute `real_value` by the quantization
+// parameters in `tensor_type`. Returns an empty Attribute if `tensor_type`
+// is not a QuantizedType or the quantization fails.
+ElementsAttr Quantize(Attribute real_value, Type tensor_type);
+
+// Quantizes the elements in "legacy mode", where it calls TOCO's methods to
+// quantize values with float scale.
+ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
+
+// Returns the quantized type for an element attribute. The quantization
+// parameters in this type are based on the min and max elements of the
+// attribute. If the elements in `attr` are not floating-point, or the value
+// range doesn't straddle zero, an empty type is returned. The min/max are
+// adjusted to be symmetric if the `symmetric` flag is set to true, and
+// `symmetric` can only be true when the type is signed and narrow_range is
+// set.
+Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
+                                      unsigned num_bits, bool is_signed,
+                                      bool narrow_range,
+                                      bool legacy_float_scale = false,
+                                      bool use_fake_quant_num_bits = false);
+
+// Returns the per-channel quantized type for an element attribute.
+// `quant_dim` defines the quantization axis. The channel min/max are adjusted
+// to be symmetric if the `symmetric` flag is set to true, and `symmetric` can
+// only be true when the type is signed and narrow_range is set.
+Type GetUniformQuantizedPerAxisTypeForWeight(
+    ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits,
+    bool is_signed, bool narrow_range, bool legacy_float_scale = false,
+    bool use_fake_quant_num_bits = false);
+
+// Returns the quantized type of a bias input, given the quantized types of
+// the other operands which are multiply-accumulated (the bias is added to the
+// accumulated value).
+quant::QuantizedType GetUniformQuantizedTypeForBias(
+    const std::vector<quant::QuantizedType>& op_types, int adjusted_quant_dim,
+    bool legacy_float_scale = false);
+
+// Gets quantization scale specs (e.g. fixed output range, same result and
+// operand scales) from the default quantization interfaces. The op should
+// outlive the returned spec for its interface methods to be properly
+// referenced.
+std::unique_ptr<OpQuantScaleSpec> GetDefaultQuantScaleSpec(Operation* op);
+
+// The function might contain more stats ops than required, and conflicting
+// calibration stats would introduce requantize ops. This method tries to
+// remove all the redundant stats ops.
+bool RemoveRedundantStatsOps(mlir::func::FuncOp func,
+                             OpQuantSpecGetter op_quant_spec_getter,
+                             OpQuantScaleSpecGetter op_quant_scale_spec_getter =
+                                 GetDefaultQuantScaleSpec);
+
+// Given quantization parameters for int8, computes the quantization parameters
+// for uint if required, and wraps the result in a UniformQuantizedType.
+quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
+                                                Type tensor_type, double scale,
+                                                int64_t zero_point,
+                                                int64_t storage_min,
+                                                int64_t storage_max);
+
+quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
+                                                Type tensor_type, double scale,
+                                                int64_t zero_point);
+
+// Extracts min and max values from the DenseFPElementsAttr and stores them
+// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
+// `dim_size` is the number of channels and `slice_size` is the size of the
+// slice per channel. When `symmetric` is true, the range is expanded to
+// [-M, M].
+void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs); + +// Returns the quantized type for the +// input_type/min/max/storage_type_width/narrow_range. +Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, bool is_signed, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 5bfa389e92a08a..a86467e79d4017 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -686,7 +686,7 @@ static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value, // Function `_copy_to_device`. static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, PyObject* kwds) { - if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr; + if (!PyArg_NoKeywords("copy_to_device", kwds)) return nullptr; const char* device_name = nullptr; if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName, @@ -874,7 +874,7 @@ static int EagerTensor_traverse(PyObject* self, visitproc visit, void* arg) { PyObject*& dict = *_PyObject_GetDictPtr(self); Py_VISIT(dict); #else - _PyObject_VisitManagedDict(self, visit, arg); + PyObject_VisitManagedDict(self, visit, arg); #endif // PY_VERSION_HEX < 0x030C0000 Py_VISIT(((EagerTensor*)self)->handle_data); Py_VISIT(((EagerTensor*)self)->tensor_shape); @@ -897,7 +897,7 @@ extern int EagerTensor_clear(PyObject* self) { PyObject*& dict = *_PyObject_GetDictPtr(self); Py_CLEAR(dict); #else - _PyObject_ClearManagedDict(self); + PyObject_ClearManagedDict(self); #endif // PY_VERSION_HEX < 0x030C0000 Py_CLEAR(((EagerTensor*)self)->handle_data); diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl index 47fe64d7b7b039..2af9b29d7af20b 100644 --- a/tensorflow/tools/toolchains/python/python_repo.bzl +++ b/tensorflow/tools/toolchains/python/python_repo.bzl @@ -7,7 +7,7 @@ Defaults to 3.10. To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ -VERSIONS = ["3.9", "3.10", "3.11", "3.12"] +VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] DEFAULT_VERSION = "3.11" WARNING = """ TF_PYTHON_VERSION environment variable was not set correctly; using Python {}.
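The quantization header added above documents several numeric conventions only in comments. The standalone C++ sketch below restates the calibration-range adjustments performed by ConvertStatsToQDQs before the range is turned into quantization parameters (always include zero, mirror the range for 16-bit, widen near-degenerate ranges). It is an illustration only, not part of the patch; AdjustCalibrationRange is a hypothetical helper name.

// Illustration only -- not part of the patch. Restates the range adjustments
// made in ConvertStatsToQDQs::matchAndRewrite and TensorRangeSanityCheck.
#include <algorithm>
#include <cmath>
#include <utility>

constexpr double kNearZeroToleranceSketch = 1.0e-6;  // same value as the header

// Returns the adjusted (rmin, rmax) calibration range.
std::pair<double, double> AdjustCalibrationRange(double rmin, double rmax,
                                                 int num_bits) {
  // Include 0.0 so the nudged zero point is always representable.
  rmin = std::min(rmin, 0.0);
  rmax = std::max(rmax, 0.0);
  // The 16x8 kernels assume symmetric quantization, so mirror the larger
  // magnitude around zero.
  if (num_bits == 16) {
    rmax = std::max(std::abs(rmin), std::abs(rmax));
    rmin = -rmax;
  }
  // A near-degenerate range would produce a (nearly) zero scale; widen it
  // symmetrically, as the sanity check in the header does.
  if (std::abs(rmax - rmin) < kNearZeroToleranceSketch) {
    rmin -= kNearZeroToleranceSketch;
    rmax += kNearZeroToleranceSketch;
  }
  return {rmin, rmax};
}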
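The comments on GetUniformQuantizedTypeForWeight and GetQuantizedType describe deriving a scale and zero point from a [rmin, rmax] range. The sketch below shows the usual affine arithmetic under those assumptions; the function names are hypothetical, and this is not the library's actual implementation, which additionally handles nudging, narrow_range, and legacy float scales.

// Illustration only -- not part of the patch.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

struct AffineParams {
  double scale;
  int64_t zero_point;
};

// Asymmetric parameters for a [rmin, rmax] range mapped onto the integer
// storage range [qmin, qmax] (e.g. 0..255 for unsigned 8-bit storage).
AffineParams AsymmetricParams(double rmin, double rmax, int64_t qmin,
                              int64_t qmax) {
  double scale = (rmax - rmin) / static_cast<double>(qmax - qmin);
  int64_t zero_point = static_cast<int64_t>(std::llround(qmin - rmin / scale));
  zero_point = std::clamp(zero_point, qmin, qmax);
  return {scale, zero_point};
}

// Symmetric parameters as used for weights: signed storage, zero point fixed
// at 0, matching the `symmetric` behaviour described in the header comments.
AffineParams SymmetricParams(double rmin, double rmax, int num_bits) {
  double bound = std::max(std::abs(rmin), std::abs(rmax));
  int64_t qmax = (int64_t{1} << (num_bits - 1)) - 1;  // 127 for 8 bits
  return {bound / static_cast<double>(qmax), 0};
}

int main() {
  AffineParams a = AsymmetricParams(-1.0, 3.0, 0, 255);
  AffineParams s = SymmetricParams(-0.8, 0.5, 8);
  std::cout << a.scale << " " << a.zero_point << "\n";  // ~0.0157, 64
  std::cout << s.scale << " " << s.zero_point << "\n";  // ~0.0063, 0
  return 0;
}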
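Finally, ConvertUnsignedToSigned rewrites unsigned quantized types to signed ones while preserving the scales; the whole transformation reduces to shifting the zero point and the storage bounds by 2^(num_bits - 1). A minimal sketch with illustrative names, assuming uniform (non-per-axis) parameters:

// Illustration only -- not part of the patch. Shows the zero-point shift that
// ConvertUnsignedToSigned applies when converting e.g. uint8 types to int8.
#include <cassert>
#include <cstdint>
#include <iostream>

struct UniformParams {
  double scale;
  int64_t zero_point;
  int64_t storage_min;
  int64_t storage_max;
};

// Shifts an unsigned quantized type (e.g. uint8 in [0, 255]) to its signed
// equivalent (int8 in [-128, 127]) without changing the represented reals:
// real = scale * (q - zero_point) is preserved because q and zero_point move
// by the same offset.
UniformParams ToSigned(const UniformParams& u, int num_bits) {
  const int64_t offset = int64_t{1} << (num_bits - 1);  // 128 for 8 bits
  return {u.scale, u.zero_point - offset, u.storage_min - offset,
          u.storage_max - offset};
}

int main() {
  UniformParams uint8_params{/*scale=*/0.5, /*zero_point=*/128,
                             /*storage_min=*/0, /*storage_max=*/255};
  UniformParams int8_params = ToSigned(uint8_params, /*num_bits=*/8);
  // zero_point 128 -> 0, range [0, 255] -> [-128, 127], scale unchanged.
  assert(int8_params.zero_point == 0);
  std::cout << int8_params.scale << " " << int8_params.zero_point << "\n";
  return 0;
}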