From 49f6239f513d74e10c19569cbb2bc19e02b6a190 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Mon, 26 Jan 2026 23:34:04 -0800 Subject: [PATCH 1/2] Add group processing tests to FiLM functionality - Introduced two new test functions: `test_process_with_groups` and `test_process_with_groups_scale_only` to validate the behavior of the FiLM class with grouped convolution parameters. - Updated `run_tests.cpp` to include these new tests, ensuring comprehensive coverage of group processing scenarios. - Enhanced the FiLM constructor to accept a groups parameter, allowing for flexible configuration of grouped convolutions in the processing logic. - Adjusted relevant header files to support the new groups parameter in the FiLM class and its associated structures. --- NAM/film.h | 5 +- NAM/wavenet.cpp | 3 +- NAM/wavenet.h | 36 ++++++----- tools/run_tests.cpp | 2 + tools/test/test_film.cpp | 131 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 17 deletions(-) diff --git a/NAM/film.h b/NAM/film.h index 9e1ec25..f0f86fb 100644 --- a/NAM/film.h +++ b/NAM/film.h @@ -23,8 +23,9 @@ class FiLM /// \param condition_dim Size of the conditioning input /// \param input_dim Size of the input to be modulated /// \param shift Whether to apply both scale and shift (true) or only scale (false) - FiLM(const int condition_dim, const int input_dim, const bool shift) - : _cond_to_scale_shift(condition_dim, (shift ? 2 : 1) * input_dim, /*bias=*/true) + /// \param groups Number of groups for grouped convolution in the condition-to-scale-shift submodule (default: 1) + FiLM(const int condition_dim, const int input_dim, const bool shift, const int groups = 1) + : _cond_to_scale_shift(condition_dim, (shift ? 
2 : 1) * input_dim, /*bias=*/true, groups) , _do_shift(shift) { } diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 0d7fa5c..5d2ae99 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -760,7 +760,8 @@ std::unique_ptr nam::wavenet::Factory(const nlohmann::json& config, st const nlohmann::json& film_config = layer_config[key]; bool active = film_config.value("active", true); bool shift = film_config.value("shift", true); - return nam::wavenet::_FiLMParams(active, shift); + int groups = film_config.value("groups", 1); + return nam::wavenet::_FiLMParams(active, shift, groups); }; // Parse FiLM parameters diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 2cf8be7..e324849 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -67,13 +67,16 @@ struct _FiLMParams /// \brief Constructor /// \param active_ Whether FiLM is active at this location /// \param shift_ Whether to apply both scale and shift (true) or only scale (false) - _FiLMParams(bool active_, bool shift_) + /// \param groups_ Number of groups for grouped convolution in the condition-to-scale-shift submodule (default: 1) + _FiLMParams(bool active_, bool shift_, int groups_ = 1) : active(active_) , shift(shift_) + , groups(groups_) { } const bool active; ///< Whether FiLM is active const bool shift; ///< Whether to apply shift in addition to scale + const int groups; ///< Number of groups for grouped convolution in the condition-to-scale-shift submodule }; /// \brief Parameters for constructing a single Layer @@ -215,48 +218,53 @@ class _Layer // Initialize FiLM objects if (params.conv_pre_film_params.active) { - _conv_pre_film = - std::make_unique(params.condition_size, params.channels, params.conv_pre_film_params.shift); + _conv_pre_film = std::make_unique( + params.condition_size, params.channels, params.conv_pre_film_params.shift, params.conv_pre_film_params.groups); } if (params.conv_post_film_params.active) { const int conv_out_channels = (params.gating_mode != GatingMode::NONE) ? 
2 * params.bottleneck : params.bottleneck; - _conv_post_film = - std::make_unique(params.condition_size, conv_out_channels, params.conv_post_film_params.shift); + _conv_post_film = std::make_unique(params.condition_size, conv_out_channels, + params.conv_post_film_params.shift, params.conv_post_film_params.groups); } if (params.input_mixin_pre_film_params.active) { _input_mixin_pre_film = - std::make_unique(params.condition_size, params.condition_size, params.input_mixin_pre_film_params.shift); + std::make_unique(params.condition_size, params.condition_size, params.input_mixin_pre_film_params.shift, + params.input_mixin_pre_film_params.groups); } if (params.input_mixin_post_film_params.active) { const int input_mixin_out_channels = (params.gating_mode != GatingMode::NONE) ? 2 * params.bottleneck : params.bottleneck; - _input_mixin_post_film = std::make_unique( - params.condition_size, input_mixin_out_channels, params.input_mixin_post_film_params.shift); + _input_mixin_post_film = + std::make_unique(params.condition_size, input_mixin_out_channels, + params.input_mixin_post_film_params.shift, params.input_mixin_post_film_params.groups); } if (params.activation_pre_film_params.active) { const int z_channels = (params.gating_mode != GatingMode::NONE) ? 
2 * params.bottleneck : params.bottleneck; _activation_pre_film = - std::make_unique(params.condition_size, z_channels, params.activation_pre_film_params.shift); + std::make_unique(params.condition_size, z_channels, params.activation_pre_film_params.shift, + params.activation_pre_film_params.groups); } if (params.activation_post_film_params.active) { _activation_post_film = - std::make_unique(params.condition_size, params.bottleneck, params.activation_post_film_params.shift); + std::make_unique(params.condition_size, params.bottleneck, params.activation_post_film_params.shift, + params.activation_post_film_params.groups); } if (params._1x1_post_film_params.active) { - _1x1_post_film = - std::make_unique(params.condition_size, params.channels, params._1x1_post_film_params.shift); + _1x1_post_film = std::make_unique(params.condition_size, params.channels, + params._1x1_post_film_params.shift, params._1x1_post_film_params.groups); } if (params.head1x1_post_film_params.active && params.head1x1_params.active) { - _head1x1_post_film = std::make_unique( - params.condition_size, params.head1x1_params.out_channels, params.head1x1_post_film_params.shift); + _head1x1_post_film = + std::make_unique(params.condition_size, params.head1x1_params.out_channels, + params.head1x1_post_film_params.shift, params.head1x1_post_film_params.groups); } }; diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 0e5ad65..be65760 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -125,6 +125,8 @@ int main() test_film::test_process_inplace_with_shift(); test_film::test_process_inplace_scale_only(); test_film::test_process_inplace_partial_frames(); + test_film::test_process_with_groups(); + test_film::test_process_with_groups_scale_only(); test_film_realtime_safe::test_allocation_tracking_pass(); test_film_realtime_safe::test_allocation_tracking_fail(); diff --git a/tools/test/test_film.cpp b/tools/test/test_film.cpp index ec8bb95..57f0acd 100644 --- a/tools/test/test_film.cpp 
+++ b/tools/test/test_film.cpp @@ -346,4 +346,135 @@ void test_process_inplace_partial_frames() } } } + +void test_process_with_groups() +{ + // Test FiLM with groups parameter + // Using groups=2, condition_dim=4, input_dim=4 + // Internal Conv1x1: condition_dim=4 -> 2*input_dim=8 (with shift) + // With groups=2: + // - Group 0: processes condition[0:1] -> output[0:3] (out_per_group = 8/2 = 4) + // - Group 1: processes condition[2:3] -> output[4:7] + // Output layout: [scale[0:3], shift[0:3]] + // So Group 0 produces all four scale values: output[0:3] = scale[0:3] + // Group 1 produces all four shift values: output[4:7] = shift[0:3] + const int condition_dim = 4; + const int input_dim = 4; + const int groups = 2; + const bool shift = true; + nam::FiLM film(condition_dim, input_dim, shift, groups); + + const int maxBufferSize = 64; + film.SetMaxBufferSize(maxBufferSize); + + // Weight layout for grouped Conv1x1: + // For groups=2, out_channels=8, in_channels=4: + // - out_per_group = 4, in_per_group = 2 + // - Group 0: weight[0:3, 0:1] = 4*2 = 8 weights + // - Group 1: weight[4:7, 2:3] = 4*2 = 8 weights + // - Bias: 8 values + // Weight vector layout: [group0 weights (row-major), group1 weights (row-major), biases] + std::vector weights; + weights.resize(8 * 2 + 8, 0.0f); // 2 groups * 8 weights + 8 biases + + // Set biases to create distinct scale/shift values + // Output indices: [0:3] = scale[0:3], [4:7] = shift[0:3] + const int bias_offset = 8 * 2; + weights[bias_offset + 0] = 2.0f; // scale[0] (from group 0) + weights[bias_offset + 1] = -1.0f; // scale[1] (from group 0) + weights[bias_offset + 2] = 0.5f; // scale[2] (from group 0) + weights[bias_offset + 3] = 3.0f; // scale[3] (from group 0) + weights[bias_offset + 4] = 10.0f; // shift[0] (from group 1) + weights[bias_offset + 5] = -20.0f; // shift[1] (from group 1) + weights[bias_offset + 6] = 5.0f; // shift[2] (from group 1) + weights[bias_offset + 7] = -15.0f; // shift[3] (from group 1) + + auto it = 
weights.begin(); + film.set_weights_(it); + assert(it == weights.end()); + + const int num_frames = 3; + Eigen::MatrixXf input(input_dim, num_frames); + // Make each channel distinct + input << 1.0f, 2.0f, 3.0f, // + 2.0f, 4.0f, 6.0f, // + 3.0f, 6.0f, 9.0f, // + 4.0f, 8.0f, 12.0f; + + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); // doesn't matter because weights are zero + + film.Process(input, condition, num_frames); + const auto out = film.GetOutput().leftCols(num_frames); + + // Expected: output = input * scale + shift (elementwise) + const float scales[4] = {2.0f, -1.0f, 0.5f, 3.0f}; + const float shifts[4] = {10.0f, -20.0f, 5.0f, -15.0f}; + for (int c = 0; c < input_dim; c++) + { + for (int t = 0; t < num_frames; t++) + { + const float expected = input(c, t) * scales[c] + shifts[c]; + assert(std::abs(out(c, t) - expected) < 1e-6f); + } + } +} + +void test_process_with_groups_scale_only() +{ + // Test FiLM with groups parameter in scale-only mode + const int condition_dim = 4; + const int input_dim = 4; + const int groups = 2; + const bool shift = false; + nam::FiLM film(condition_dim, input_dim, shift, groups); + + const int maxBufferSize = 64; + film.SetMaxBufferSize(maxBufferSize); + + // Internal Conv1x1: condition_dim=4 -> input_dim=4 (scale-only) + // With groups=2: + // - Group 0: condition[0:1] -> scale[0:1] + // - Group 1: condition[2:3] -> scale[2:3] + // Each group processes 2 input channels -> 2 output channels + // Weight layout: [group0 weights (2x2), group1 weights (2x2), biases (4)] + + std::vector weights; + weights.resize((2 * 2) * 2 + 4, 0.0f); // 2 groups * (2 outputs * 2 inputs) + 4 biases + + // Set biases to create distinct scale values per group + const int bias_offset = (2 * 2) * 2; + weights[bias_offset + 0] = 2.0f; // scale[0] + weights[bias_offset + 1] = -1.0f; // scale[1] + weights[bias_offset + 2] = 0.5f; // scale[2] + weights[bias_offset + 3] = 3.0f; // scale[3] + + auto it = weights.begin(); + 
film.set_weights_(it); + assert(it == weights.end()); + + const int num_frames = 2; + Eigen::MatrixXf input(input_dim, num_frames); + input << 1.0f, 2.0f, // + 2.0f, 4.0f, // + 3.0f, 6.0f, // + 4.0f, 8.0f; + + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); // doesn't matter because weights are zero + + film.Process(input, condition, num_frames); + const auto out = film.GetOutput().leftCols(num_frames); + + // Expected: output = input * scale (elementwise) + const float scales[4] = {2.0f, -1.0f, 0.5f, 3.0f}; + for (int c = 0; c < input_dim; c++) + { + for (int t = 0; t < num_frames; t++) + { + const float expected = input(c, t) * scales[c]; + assert(std::abs(out(c, t) - expected) < 1e-6f); + } + } +} } // namespace test_film From da4ea04d1358249cf62dec2fcdb6803b9e99969d Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Mon, 26 Jan 2026 23:45:08 -0800 Subject: [PATCH 2/2] update wavenet_a2_max.nam --- example_models/wavenet_a2_max.nam | 764 ++---------------------------- generate_weights_a2.py | 70 ++- 2 files changed, 105 insertions(+), 729 deletions(-) diff --git a/example_models/wavenet_a2_max.nam b/example_models/wavenet_a2_max.nam index 117e71e..e820526 100644 --- a/example_models/wavenet_a2_max.nam +++ b/example_models/wavenet_a2_max.nam @@ -66,35 +66,43 @@ "secondary_activation": "Hardswish", "conv_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "conv_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "input_mixin_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "input_mixin_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "activation_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "activation_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "1x1_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "head1x1_post_film": 
{ "active": true, - "shift": true + "shift": true, + "groups": 1 } }, { @@ -155,35 +163,43 @@ ], "conv_pre_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "conv_post_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "input_mixin_pre_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "input_mixin_post_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "activation_pre_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "activation_post_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "1x1_post_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 }, "head1x1_post_film": { "active": true, - "shift": false + "shift": false, + "groups": 1 } } ], @@ -1242,7 +1258,7 @@ 0.8908312859402125, -0.9132699262016566, 0.5664539919607592, - 0.01 + 0.7339618155196765 ], "sample_rate": 48000 }, @@ -1274,35 +1290,43 @@ "secondary_activation": "", "conv_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 2 }, "conv_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 4 }, "input_mixin_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 4 }, "input_mixin_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 2 }, "activation_pre_film": { "active": true, - "shift": true + "shift": true, + "groups": 1 }, "activation_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 2 }, "1x1_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 8 }, "head1x1_post_film": { "active": true, - "shift": true + "shift": true, + "groups": 4 } } ], @@ -2127,695 +2151,7 @@ 0.14359528396980226, -0.5328751078842138, 0.5510889501985177, - -0.9127054018053962, - 0.4894103031303918, - 0.4104557620500051, - 0.6228178051297537, - -0.22784249501803533, - 0.32737765896906934, - 0.6414951034191101, - 
0.9616362773195701, - -0.009342700767152934, - -0.9259607773088476, - 0.00458230026593176, - 0.1803608586191341, - 0.7394006267244309, - 0.7483807481086167, - -0.11938758045864573, - 0.05190217361512284, - -0.08614385105151556, - 0.44488765514128525, - -0.1800427605072712, - 0.3095626528554676, - -0.6912775624569023, - -0.061018798030536114, - 0.9384072611484344, - -0.32287753189713664, - 0.3854091971737641, - 0.2996733051585452, - 0.7035305847013829, - 0.7046826731317861, - 0.7186843683365332, - -0.2399812049546899, - -0.36667769213320067, - 0.437434850445966, - 0.5188036186687273, - 0.7447660347970726, - -0.9282018003246268, - -0.8631585056368971, - 0.2623220337092811, - 0.8418581975605934, - 0.9948518459124307, - 0.4935327334758539, - -0.13205705615374863, - -0.8031137472336041, - 0.26749565756158256, - 0.7451584652141814, - -0.11264289667658844, - 0.3880023255864691, - 0.806848124100972, - -0.9080180626455745, - 0.5922869303244236, - -0.4132644480695824, - -0.25031782044937034, - -0.7088604086012031, - 0.06233263629752006, - 0.13185612383187362, - 0.5850389477619065, - -0.6600327018638052, - -0.8420632986800107, - 0.7416791972897097, - 0.2394207370748307, - -0.5183404162873739, - 0.8256580320471096, - -0.7137645597429807, - -0.07770017329000489, - -0.49204532117291544, - -0.4893465836816919, - -0.9812051370902579, - 0.6092661539502926, - 0.802418847197766, - 0.3552217713981589, - -0.6840487558564261, - -0.11654043279237114, - -0.3088687511138366, - 0.17514341025284286, - 0.2778774047200243, - -0.15138212307820464, - -0.4998035511845975, - 0.6906078502851973, - -0.6015660017822153, - -0.23061350206530906, - -0.0335838778815678, - -0.5255885961448508, - 0.14384538470147779, - 0.14962386035729014, - 0.9853840872537951, - -0.4095384922334613, - 0.9558889691537253, - 0.31645963185736137, - -0.45103923964354453, - 0.13185803391139794, - 0.37159898546802417, - 0.48933768233065034, - -0.9019114984480001, - 0.21281298615294908, - -0.006545426952259348, - 
0.8083105817874507, - -0.42761169708161106, - 0.5977202390151974, - 0.2141299963285319, - -0.29535808832936694, - 0.27323575601178374, - 0.24178232626060825, - 0.35552891724671887, - 0.44185675334141794, - 0.3183630806335074, - 0.6766742339250331, - 0.25649620737379664, - 0.8068074081459358, - 0.2926812177810456, - -0.38213423209472785, - -0.11835361967437685, - 0.1591476107368055, - 0.4647195358784766, - -0.8197332485075621, - -0.40977909674168944, - 0.49496172987674303, - -0.6487198591113867, - -0.7356804045064329, - 0.07881551796884212, - 0.9429791624226798, - 0.06170474740557008, - 0.8269739489633938, - 0.6609452391347947, - -0.4860598308734754, - 0.6493796250848145, - -0.03630434025251761, - 0.612976987587533, - 0.4931187014340901, - -0.32256949239622745, - -0.7696605851004812, - 0.9257865857377552, - -0.7184859699882347, - 0.9330004189255041, - 0.7202811937976437, - 0.4484334241510364, - 0.9598844855638069, - 0.9345394946000647, - 0.6091752880411239, - -0.26844990118886725, - 0.5813639371778749, - -0.9721626897981976, - 0.07314461653811821, - -0.09042794453225067, - 0.3456567637475072, - 0.3446815947020263, - 0.1691201833043321, - 0.6448346024535485, - 0.8805837835591084, - -0.7833077956015386, - -0.5323561955768337, - -0.9499507007070354, - 0.7684696904297044, - 0.1228147644997577, - 0.8305118174863189, - -0.55726559985202, - -0.8735659176796091, - 0.6477107027808953, - 0.8187752768557841, - -0.39561965094149953, - -0.18340828840917456, - -0.7204459749855578, - 0.8925230657638925, - -0.39127083128798956, - -0.014750762043588272, - -0.8056160027563568, - 0.7745186170570046, - -0.728671902587327, - -0.09271248622221484, - 0.34097243770034247, - 0.4862802430463431, - 0.8919481715588642, - -0.16174649317055434, - 0.48453802953063163, - -0.69095419518033, - -0.17023094512638615, - -0.801956730578943, - -0.02130592442077095, - -0.18376822860459896, - 0.903043050762119, - -0.9345674262899062, - -0.2589400825311292, - -0.1132338278785967, - 0.9011103397028539, - 
0.7109003866119092, - -0.8012907507877394, - 0.3713605309625707, - 0.08893172296428986, - 0.9556850589040935, - -0.28265231757536413, - -0.20372071451125384, - -0.6203828756778409, - -0.7556805618254725, - 0.6960663769273621, - -0.09056526285896571, - 0.325537476123956, - 0.2834089344664352, - 0.19429191903909016, - -0.9572850905272587, - 0.5735891809092335, - -0.5128622056719527, - -0.7481522293839142, - 0.12915595181592665, - -0.8627796943512882, - 0.5303147517771689, - -0.585685259306683, - -0.5680972961626505, - 0.7393908535390894, - -0.34288089313553916, - -0.7048916401171126, - 0.8010620712635164, - -0.9943288970399673, - 0.7168122527603527, - -0.710624039358984, - -0.7400157371131264, - -0.49869160654375233, - -0.6510057581972131, - 0.3221152851946336, - -0.9484397004276033, - -0.9702793455385823, - 0.5799693284695078, - -0.5241367878190739, - -0.3524570760759551, - -0.6515075971876456, - -0.895201964277647, - 0.483436113908299, - 0.05217105319573423, - 0.4913305500679914, - -0.04750806915349948, - 0.5560340786284055, - 0.026475915218383328, - -0.781891979992309, - 0.007677379571642717, - 0.8908312859402125, - -0.9132699262016566, - 0.5664539919607592, - 0.7339618155196765, - 0.042902429426168176, - -0.08391495580463593, - 0.9280523662440574, - -0.8783491850109884, - -0.04203617800327342, - -0.19676549097487905, - 0.3721949921244654, - -0.01946229171154945, - 0.8194016582304586, - -0.853018568467099, - -0.8384190451784161, - 0.21659484726687994, - -0.8686355333597722, - -0.4499680008840936, - 0.26615344860203094, - 0.09671286809672108, - -0.3496291133626641, - 0.9892555117218471, - 0.06111367486264907, - -0.09256916484905675, - 0.21085358307062574, - -0.801643076641916, - 0.40355883709213236, - 0.7055854745911501, - 0.3018333297626108, - 0.5379254602094772, - 0.44167983331519833, - -0.5699538673450062, - -0.09689016806943695, - -0.5430128512708312, - -0.3221367623296496, - -0.09300219418510536, - -0.16802069947700948, - -0.8098283214487383, - 
-0.146471987478084, - 0.3302157261206178, - -0.2513979531289263, - -0.6947221504625525, - 0.8459700714687688, - -0.8657333837267045, - 0.663543776949709, - -0.8135397965926452, - -0.8068711348684288, - 0.4775919969774314, - 0.6235385705547847, - 0.11274147123070022, - 0.1729301654789539, - 0.12317282798414486, - -0.3407080371675899, - -0.7555374292908814, - -0.2928038407324647, - 0.3306810400058309, - 0.5005685005029565, - 0.7361842977381299, - 0.44212135749229886, - 0.9367972506229489, - 0.20082018244935407, - -0.29670762861370203, - 0.15583703677970973, - -0.5745223886559878, - 0.3134726059763042, - -0.5515102617848688, - -0.7835632361454667, - 0.6907468372026901, - -0.2648778987692997, - 0.5252112638736994, - 0.14820000866292538, - 0.6144427423046888, - 0.6903103226567162, - 0.9490932042514164, - 0.6368537190813393, - 0.22714656107092956, - 0.2853983276596628, - -0.9474923370928332, - 0.8581685819898728, - 0.658921579919326, - -0.4651045496917787, - -0.6391678560782912, - 0.40539754573120934, - -0.3820306223401797, - -0.32035068644548326, - -0.9877884211926826, - 0.7397254130728763, - 0.13264218952275253, - -0.19843131200096886, - -0.7162506916974627, - 0.26634402531102763, - -0.9386858032381882, - 0.49222352401141345, - -0.5697342399329781, - -0.1603350124709937, - -0.31820803646133444, - -0.2598938150459922, - 0.4431919354853464, - 0.553671239933482, - 0.1351871132287945, - -0.8300859200456414, - -0.8947823471489564, - -0.6851802057856937, - 0.23567636385206114, - 0.34793742122622606, - -0.4557943129075621, - 0.32387738561652246, - -0.028676590218007503, - -0.11591162660444176, - -0.4536663112704611, - 0.5098862873367414, - -0.7723649837795965, - -0.1401727332042193, - -0.4335070598391013, - 0.35697250952026605, - -0.02673449327149835, - 0.33426511747269916, - -0.909165274791145, - -0.2094732078247421, - 0.19864991388890085, - -0.9846258282002343, - -0.39716127596312645, - -0.5775320415593754, - -0.725530389485346, - -0.4889609919570814, - 
-0.34375528812431777, - -0.9845401868613761, - 0.4940282468595356, - -0.6486103962483154, - -0.23958510856952753, - 0.40734252676532723, - 0.000524693112426311, - 0.6667084048397565, - 0.6124003731335277, - -0.8558490068156843, - 0.7235287240451771, - -0.9153954768744172, - -0.9625169268287086, - 0.8423248690048248, - 0.7242200272241328, - 0.15151832147366218, - 0.14679936177168607, - 0.41899792313786843, - -0.16461208031303887, - -0.7696532546724035, - -0.9582868819505024, - -0.3504636411089723, - 0.6026443086209043, - 0.23625052660848045, - 0.6640518261434143, - 0.8395395034827693, - -0.8237402374042306, - 0.6889687196293939, - -0.5133670503545327, - 0.17774257660582382, - 0.04792508600126322, - -0.20846660628133407, - -0.37945087632827734, - -0.3209734377030722, - -0.3338627550137361, - -0.6637345839086268, - 0.020966569084296616, - -0.7719467203228949, - 0.019904124645943932, - 0.8118454631601009, - -0.3012494690552394, - 0.4547582113478086, - 0.6378972030497039, - 0.6300740115000283, - -0.5274623021065514, - -0.7071115634446112, - -0.6054563943520335, - 0.20479797054633164, - 0.5204305910937841, - 0.3110180210374782, - -0.6457077421055417, - 0.5456961784951204, - -0.011765949965222733, - 0.5088916504939716, - 0.5197542992154969, - -0.10218948600987399, - 0.8483085167713187, - 0.12898356680559853, - 0.27059663812118684, - 0.2490435588837896, - 0.7284937496628323, - 0.2544348137995436, - -0.6980851972138462, - -0.8634274830084878, - -0.11558387232743028, - -0.3943591297219742, - -0.45065266502768164, - -0.8876557595738439, - 0.014673770579556322, - -0.37918429878737814, - -0.09617227259863559, - -0.8862198983320952, - 0.6633932633262429, - -0.84653799765309, - 0.7285000678504783, - 0.7105867429121806, - 0.23001677683114718, - 0.014135634667910235, - -0.07457668214545543, - 0.10863274267660783, - 0.583635594530717, - 0.7917535311136052, - -0.10053259269818748, - 0.6196318353959422, - 0.30367490929721974, - -0.3569464742391433, - -0.04874194411386101, - 
-0.6982778466087229, - -0.8762525997977986, - -0.7929962454592168, - 0.7982536679115471, - -0.31312444832346475, - 0.4286310983349253, - 0.009098003001914767, - -0.6548821771272701, - -0.5045125528034315, - -0.12448345138153405, - -0.1211564164747514, - 0.04549607053109117, - -0.6825075840369172, - -0.25429603579734583, - -0.43421284277116423, - -0.18246120550535938, - -0.3232657063171609, - 0.19577172476597893, - 0.5784538631273295, - 0.2946107139386138, - -0.8681762914405651, - -0.8109881049698113, - 0.356758689680162, - -0.43170605224377256, - 0.44746731000492357, - 0.31312817282095584, - 0.8126853943273544, - 0.7465593241210777, - -0.3332759278788868, - 0.1654790291728392, - -0.7171432388313643, - -0.30035842492831155, - 0.9353930153854892, - 0.3969599256237619, - -0.21608403132927934, - 0.19008245630314957, - 0.8760043991315216, - -0.38083622516807814, - -0.24664138788739964, - 0.5833239157270871, - 0.6263695676296019, - 0.340232799989445, - 0.657917945788929, - 0.47754934425896556, - 0.3708288805551525, - 0.05278667946834803, - 0.29204964146697576, - -0.15318726735579258, - -0.2763438073064308, - -0.27480466203591125, - -0.6394741542463107, - -0.5716146775821993, - 0.8953365350686937, - -0.02745815825454545, - -0.5469139069322255, - -0.7248692936458414, - -0.8456698313982474, - 0.6888567773717666, - -0.7977184726642117, - 0.541749440726202, - 0.670239653269112, - 0.7673643309851232, - -0.9245050152848775, - -0.3264712560714369, - 0.5326152088944831, - -0.737901917140984, - -0.24656025835495066, - -0.675505575823099, - 0.662690113378251, - 0.5421956274619768, - 0.6180874392786022, - -0.6689216685118826, - -0.12465318972330652, - -0.17828277700693307, - 0.3527258443770658, - -0.5249395971061499, - -0.1116025803894487, - -0.4301441348732835, - 0.4970730361908726, - -0.10214407393331926, - 0.06802229939652227, - -0.3810642068744208, - 0.6172477421815727, - -0.061968789935494595, - 0.6702267854514148, - -0.26431808354993436, - 0.8942603404882459, - 
0.9688795870631546, - -0.07664004311802164, - -0.43645653459244915, - -0.23625513161857903, - 0.05491957692296534, - 0.9325363064118946, - 0.6337824791624802, - 0.6025184483030952, - -0.7232029307931376, - -0.49999357682198586, - 0.2823580724088943, - 0.748233890104474, - 0.10908149084886243, - -0.7948205365031813, - 0.691784553466877, - 0.7023320961695576, - -0.42987397188134313, - 0.5262336605833817, - -0.45441740081728677, - 0.8106124179564824, - -0.7053026880167914, - -0.12505487961003836, - 0.8928265260235213, - -0.5559239861860388, - -0.09774401955858214, - -0.30082984372270216, - -0.9466596183941229, - -0.8934862256578511, - 0.004014229386470891, - -0.5284438522731347, - 0.9890507024765824, - -0.25017465316446486, - -0.9436249089436999, - 0.8616518094999062, - 0.6783525752232114, - 0.29992136858836327, - 0.582761274964352, - -0.7248008245482567, - -0.42624120537346366, - 0.6595231663056451, - 0.39214397715196725, - -0.7224146163636276, - 0.41107235057816083, - -0.10279705203550682, - -0.9894976033887641, - -0.8415484574585574, - -0.4881521431259106, - 0.669926198564762, - 0.09760849088767087, - 0.4544695706498636, - 0.05554301177344856, - -0.7776262793515232, - -0.42379684392153893, - -0.39769761082756006, - -0.9045011067622974, - -0.1603489124931139, - 0.5877982172788763, - -0.08577276672710266, - -0.77828420941997, - 0.810293771323997, - 0.19347808563789393, - -0.9671292955882111, - 0.03075146041881216, - -0.5161237311580134, - -0.7128463195074799, - -0.14152221379333252, - 0.22961916555190132, - -0.5188715223869129, - -0.16686480958292038, - 0.3287426034840182, - -0.8287720900254765, - 0.9493089819044433, - -0.8646413541823079, - 0.052118890644341054, - 0.014655393159573205, - 0.9766629711929349, - 0.10830390483639096, - -0.2190925348798718, - -0.05972984368302803, - 0.2713415829372601, - 0.9620788451031206, - -0.49269947786157453, - -0.9675155377821048, - 0.5770400325590306, - -0.3103950136732174, - 0.4658820429012833, - 0.25651392495131287, - 
0.5430027482197182, - 0.4703739696246225, - -0.33496278325603024, - -0.9113286234095836, - 0.09202749041538305, - 0.6270177311121763, - -0.6498217458965831, - 0.558285186956563, - -0.07075420050592451, - 0.3907785039921279, - 0.2634716955166756, - 0.6229953636952128, - -0.8737989259355508, - 0.5523807994068435, - -0.08464084510529357, - -0.4131148576499375, - -0.9123874486817538, - -0.6010603325696233, - -0.916188116139236, - 0.8667419599007946, - 0.030767178508997572, - 0.9782454045922468, - 0.08606139530837176, - -0.4933724695947652, - 0.5065818376377298, - -0.6177931385321782, - -0.2860516479292732, - 0.56168313395685, - 0.7315965541561151, - -0.33615062723731093, - -0.7510499835112332, - -0.26396165137065397, - 0.7789730340245629, - 0.48661541103924244, - 0.7892749899101066, - -0.22671046347861568, - 0.9474471686306181, - -0.007593546925954708, - -0.004953215012745371, - 0.8486209332539272, - 0.03855170706988398, - 0.6022961748035476, - 0.45416264868527145, - -0.8421459878890643, - 0.2049065976604545, - 0.6446825590797733, - 0.09094879468927397, - -0.35757714357081194, - -0.8398621778500095, - 0.3218384429162733, - -0.38700828781849506, - 0.2052432554611996, - -0.14776785423390915, - 0.37952961689097253, - -0.29690603245601577, - -0.915289674299741, - 0.7400743501127842, - -0.2948813793830354, - 0.9963011955460981, - -0.4508892798502848, - 0.9600545583885582, - 0.8958087572061726, - -0.8499176700236815, - 0.2750250757665966, - -0.2733777386980354, - 0.6021919511243399, - 0.35882121562937974, - 0.9055787925592156, - -0.7144410632749005, - 0.21514580664171068, - 0.562623939486933, - -0.930402068402916 + -0.9127054018053962 ], "sample_rate": 48000 } \ No newline at end of file diff --git a/generate_weights_a2.py b/generate_weights_a2.py index 4aca1c1..1186928 100644 --- a/generate_weights_a2.py +++ b/generate_weights_a2.py @@ -33,23 +33,44 @@ def count_conv1x1_weights(in_channels: int, out_channels: int, return weight_count -def count_film_weights(condition_dim: 
int, input_dim: int, has_shift: bool) -> int: +def count_film_weights(condition_dim: int, input_dim: int, has_shift: bool, groups: int = 1) -> int: """ Count weights for a FiLM (Feature-wise Linear Modulation) module. - FiLM uses a Conv1x1: condition_dim -> (2*input_dim if shift else input_dim), with bias + FiLM uses a Conv1x1: condition_dim -> (2*input_dim if shift else input_dim), with bias. + + Args: + condition_dim: Size of the conditioning input + input_dim: Size of the input to be modulated + has_shift: Whether to apply both scale and shift (true) or only scale (false) + groups: Number of groups for grouped convolution (default: 1) """ out_channels = (2 * input_dim) if has_shift else input_dim - return count_conv1x1_weights(condition_dim, out_channels, has_bias=True, groups=1) + return count_conv1x1_weights(condition_dim, out_channels, has_bias=True, groups=groups) -def parse_gating_mode(layer_config: Dict[str, Any]) -> str: - """Parse gating mode from layer config (handles both old and new formats).""" +def parse_gating_mode(layer_config: Dict[str, Any], layer_index: int = 0) -> str: + """ + Parse gating mode from layer config (handles both old and new formats). 
+ + Args: + layer_config: Layer configuration dictionary + layer_index: Index of the layer (for array-valued gating_mode) + """ if "gating_mode" in layer_config: - gating_mode_str = layer_config["gating_mode"] - if gating_mode_str in ["GATED", "BLENDED", "NONE"]: - return gating_mode_str - # Handle lowercase versions - return gating_mode_str.upper() + gating_mode_value = layer_config["gating_mode"] + if isinstance(gating_mode_value, list): + # Array of gating modes - use the one at layer_index + gating_mode_str = gating_mode_value[layer_index] + else: + # Single gating mode - use for all layers + gating_mode_str = gating_mode_value + + if isinstance(gating_mode_str, str): + if gating_mode_str in ["GATED", "BLENDED", "NONE"]: + return gating_mode_str + # Handle lowercase versions + return gating_mode_str.upper() + return "NONE" elif "gated" in layer_config: # Backward compatibility return "GATED" if layer_config["gated"] else "NONE" @@ -57,7 +78,7 @@ def parse_gating_mode(layer_config: Dict[str, Any]) -> str: return "NONE" -def count_layer_weights(layer_config: Dict[str, Any], condition_size: int) -> int: +def count_layer_weights(layer_config: Dict[str, Any], condition_size: int, layer_index: int = 0) -> int: """ Count weights for a single layer (one dilation). @@ -67,6 +88,11 @@ def count_layer_weights(layer_config: Dict[str, Any], condition_size: int) -> in 3. 1x1 Conv1x1: (bottleneck, channels, bias=True, groups_1x1) 4. Optional head1x1 Conv1x1: (bottleneck, head1x1_out_channels, bias=True, head1x1_groups) 5. 
FiLM modules (optional, various configurations) + + Args: + layer_config: Layer configuration dictionary + condition_size: Size of the conditioning input + layer_index: Index of the layer within the layer array (for array-valued configs) """ channels = layer_config["channels"] bottleneck = layer_config.get("bottleneck", channels) @@ -75,7 +101,7 @@ def count_layer_weights(layer_config: Dict[str, Any], condition_size: int) -> in groups_input_mixin = layer_config.get("groups_input_mixin", 1) groups_1x1 = layer_config.get("groups_1x1", 1) - gating_mode = parse_gating_mode(layer_config) + gating_mode = parse_gating_mode(layer_config, layer_index) # Output channels are doubled for GATED and BLENDED modes conv_out_channels = 2 * bottleneck if gating_mode in ["GATED", "BLENDED"] else bottleneck @@ -128,8 +154,9 @@ def count_layer_weights(layer_config: Dict[str, Any], condition_size: int) -> in film_params = layer_config[film_key] if isinstance(film_params, dict) and film_params.get("active", True): has_shift = film_params.get("shift", True) + groups = film_params.get("groups", 1) if input_dim > 0: # Only count if input_dim is valid - weight_count += count_film_weights(condition_size, input_dim, has_shift) + weight_count += count_film_weights(condition_size, input_dim, has_shift, groups) return weight_count @@ -161,14 +188,27 @@ def count_layer_array_weights(layer_config: Dict[str, Any]) -> int: num_layers = len(dilations) + # Validate array-valued configs match number of layers + if "activation" in layer_config and isinstance(layer_config["activation"], list): + if len(layer_config["activation"]) != num_layers: + raise ValueError(f"activation array size ({len(layer_config['activation'])}) must match dilations size ({num_layers})") + + if "gating_mode" in layer_config and isinstance(layer_config["gating_mode"], list): + if len(layer_config["gating_mode"]) != num_layers: + raise ValueError(f"gating_mode array size ({len(layer_config['gating_mode'])}) must match dilations 
size ({num_layers})") + + if "secondary_activation" in layer_config and isinstance(layer_config["secondary_activation"], list): + if len(layer_config["secondary_activation"]) != num_layers: + raise ValueError(f"secondary_activation array size ({len(layer_config['secondary_activation'])}) must match dilations size ({num_layers})") + weight_count = 0 # 1. Rechannel weights weight_count += count_conv1x1_weights(input_size, channels, has_bias=False, groups=1) # 2. For each layer in the array - for _ in range(num_layers): - weight_count += count_layer_weights(layer_config, condition_size) + for layer_idx in range(num_layers): + weight_count += count_layer_weights(layer_config, condition_size, layer_idx) # 3. Head rechannel weights (input is head_output_size, not bottleneck) weight_count += count_conv1x1_weights(