diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp
index 6e1835b..f05dd07 100644
--- a/NAM/conv1d.cpp
+++ b/NAM/conv1d.cpp
@@ -55,10 +55,16 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
   this->_num_groups = groups;
   this->_weight.resize(kernel_size);
   for (size_t i = 0; i < this->_weight.size(); i++)
+  {
     this->_weight[i].resize(out_channels, in_channels); // y = Ax, input array (C,L)
+    this->_weight[i].setZero();
+  }
   if (do_bias)
+  {
     this->_bias.resize(out_channels);
+    this->_bias.setZero();
+  }
   else
     this->_bias.resize(0);
   this->_dilation = _dilation;
@@ -104,54 +110,22 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
   // Zero output before processing
   _output.leftCols(num_frames).setZero();
 
-  const int numGroups = this->_num_groups;
-  const long in_channels = get_in_channels();
-  const long out_channels = get_out_channels();
-  const long in_per_group = in_channels / numGroups;
-  const long out_per_group = out_channels / numGroups;
-
   // Process from ring buffer with dilation lookback
   // After Write(), data is at positions [_write_pos, _write_pos+num_frames-1]
   // For kernel tap k with offset, we need to read from _write_pos + offset
   // The offset is negative (looking back), so _write_pos + offset reads from earlier positions
-  // The original process_() reads: input.middleCols(i_start + offset, ncols)
-  // where i_start is the current position and offset is negative for lookback
-
-  if (numGroups == 1)
+  //
+  // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
+  // so we can use a single GEMM for all cases. A more advanced implementation could store
+  // compact per-group weight matrices and loop over groups, but at typical model sizes
+  // (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
+  // and the single sparse GEMM approach is faster.
+  for (size_t k = 0; k < this->_weight.size(); k++)
   {
-    // Standard convolution (no grouping)
-    for (size_t k = 0; k < this->_weight.size(); k++)
-    {
-      const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
-      const long lookback = -offset;
-      auto input_block = _input_buffer.Read(num_frames, lookback);
-      _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
-    }
-  }
-  else
-  {
-    // Grouped convolution: process each group separately
-    for (int g = 0; g < numGroups; g++)
-    {
-      for (size_t k = 0; k < this->_weight.size(); k++)
-      {
-        const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
-        const long lookback = -offset;
-        auto input_block = _input_buffer.Read(num_frames, lookback);
-
-        // Extract input slice for this group
-        auto input_group = input_block.middleRows(g * in_per_group, in_per_group);
-
-        // Extract weight slice for this group
-        auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
-
-        // Extract output slice for this group
-        auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);
-
-        // Perform grouped convolution: output_group += weight_group * input_group
-        output_group.noalias() += weight_group * input_group;
-      }
-    }
+    const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
+    const long lookback = -offset;
+    auto input_block = _input_buffer.Read(num_frames, lookback);
+    _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
   }
 
   // Add bias if present
@@ -167,49 +141,18 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
 void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols,
                       const long j_start) const
 {
-  const int numGroups = this->_num_groups;
-  const long in_channels = get_in_channels();
-  const long out_channels = get_out_channels();
-  const long in_per_group = in_channels / numGroups;
-  const long out_per_group = out_channels / numGroups;
-
-  if (numGroups == 1)
-  {
-    // Standard convolution (no grouping)
-    for (size_t k = 0; k < this->_weight.size(); k++)
-    {
-      const long offset = this->_dilation * (k + 1 - this->_weight.size());
-      if (k == 0)
-        output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
-      else
-        output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
-    }
-  }
-  else
+  // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
+  // so we can use a single GEMM for all cases. A more advanced implementation could store
+  // compact per-group weight matrices and loop over groups, but at typical model sizes
+  // (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
+  // and the single sparse GEMM approach is faster.
+  for (size_t k = 0; k < this->_weight.size(); k++)
   {
-    // Grouped convolution: process each group separately
-    for (int g = 0; g < numGroups; g++)
-    {
-      for (size_t k = 0; k < this->_weight.size(); k++)
-      {
-        const long offset = this->_dilation * (k + 1 - this->_weight.size());
-
-        // Extract input slice for this group
-        auto input_group = input.middleCols(i_start + offset, ncols).middleRows(g * in_per_group, in_per_group);
-
-        // Extract weight slice for this group
-        auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
-
-        // Extract output slice for this group
-        auto output_group = output.middleCols(j_start, ncols).middleRows(g * out_per_group, out_per_group);
-
-        // Perform grouped convolution
-        if (k == 0)
-          output_group.noalias() = weight_group * input_group;
-        else
-          output_group.noalias() += weight_group * input_group;
-      }
-    }
+    const long offset = this->_dilation * (k + 1 - this->_weight.size());
+    if (k == 0)
+      output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
+    else
+      output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
   }
   if (this->_bias.size() > 0)
   {
diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp
index 02a4a13..b7f5f3f 100644
--- a/NAM/dsp.cpp
+++ b/NAM/dsp.cpp
@@ -332,9 +332,13 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool
   this->_num_groups = groups;
   this->_weight.resize(out_channels, in_channels);
+  this->_weight.setZero();
   this->_do_bias = _bias;
   if (_bias)
+  {
     this->_bias.resize(out_channels);
+    this->_bias.setZero();
+  }
 }
@@ -374,45 +378,11 @@ void nam::Conv1x1::set_weights_(std::vector<float>::iterator& weights)
 Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int num_frames) const
 {
-  const int numGroups = this->_num_groups;
-  const long in_channels = get_in_channels();
-  const long out_channels = get_out_channels();
-  const long in_per_group = in_channels / numGroups;
-  const long out_per_group = out_channels / numGroups;
-
-  Eigen::MatrixXf result(out_channels, num_frames);
-
-  if (numGroups == 1)
-  {
-    // Standard convolution (no grouping)
-    if (this->_do_bias)
-      result = (this->_weight * input.leftCols(num_frames)).colwise() + this->_bias;
-    else
-      result = this->_weight * input.leftCols(num_frames);
-  }
-  else
-  {
-    // Grouped convolution: process each group separately
-    result.setZero();
-    for (int g = 0; g < numGroups; g++)
-    {
-      // Extract input slice for this group
-      auto input_group = input.leftCols(num_frames).middleRows(g * in_per_group, in_per_group);
-
-      // Extract weight slice for this group
-      auto weight_group = this->_weight.block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
-
-      // Extract output slice for this group
-      auto output_group = result.middleRows(g * out_per_group, out_per_group);
+  // Single GEMM for all cases - block-diagonal zero structure handles grouping
+  Eigen::MatrixXf result = this->_weight * input.leftCols(num_frames);
 
-      // Perform grouped convolution: output_group = weight_group * input_group
-      output_group.noalias() = weight_group * input_group;
-    }
-
-    // Add bias if present
-    if (this->_do_bias)
-      result.colwise() += this->_bias;
-  }
+  if (this->_do_bias)
+    result.colwise() += this->_bias;
 
   return result;
 }
@@ -421,40 +391,9 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames)
 {
   assert(num_frames <= _output.cols());
 
-  const int numGroups = this->_num_groups;
-  const long in_channels = get_in_channels();
-  const long out_channels = get_out_channels();
-  const long in_per_group = in_channels / numGroups;
-  const long out_per_group = out_channels / numGroups;
-
-  if (numGroups == 1)
-  {
-    // Standard convolution (no grouping)
-    _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
-  }
-  else
-  {
-    // Grouped convolution: process each group separately
-    _output.leftCols(num_frames).setZero();
-    for (int g = 0; g < numGroups; g++)
-    {
-      // Extract input slice for this group
-      auto input_group = input.leftCols(num_frames).middleRows(g * in_per_group, in_per_group);
-
-      // Extract weight slice for this group
-      auto weight_group = this->_weight.block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
-
-      // Extract output slice for this group
-      auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);
+  // Single GEMM for all cases - block-diagonal zero structure handles grouping
+  _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
 
-      // Perform grouped convolution: output_group = weight_group * input_group
-      output_group.noalias() = weight_group * input_group;
-    }
-  }
-
-  // Add bias if present
   if (this->_do_bias)
-  {
     _output.leftCols(num_frames).colwise() += this->_bias;
-  }
 }
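
A worked example of the lookback arithmetic described in the Conv1D::Process comments above, as a small self-contained sketch (the kernel_size and dilation values are illustrative, not taken from the patch):

// Each tap k reads lookback = -offset samples behind the newly written block,
// where offset = dilation * (k + 1 - kernel_size) is zero for the newest tap
// and most negative for the oldest tap.
#include <cstdio>

int main()
{
  const long kernel_size = 3, dilation = 2; // illustrative values
  for (long k = 0; k < kernel_size; k++)
  {
    const long offset = dilation * (k + 1 - kernel_size); // -4, -2, 0
    const long lookback = -offset;                        //  4,  2, 0
    std::printf("k=%ld offset=%ld lookback=%ld\n", k, offset, lookback);
  }
  return 0;
}

The oldest tap therefore looks back (kernel_size - 1) * dilation samples, so _input_buffer.Read() needs at least that much history available before the newly written block.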
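
The block-diagonal argument in the patch comments can be checked in isolation. The following is a minimal Eigen sketch, independent of the NAM classes (channel/group/frame counts are the illustrative sizes from the comment; all names are local to the example), showing that one GEMM over a weight matrix whose off-group blocks are zero matches an explicit per-group loop:

#include <Eigen/Dense>
#include <iostream>

int main()
{
  const int channels = 8, groups = 4, frames = 64; // illustrative sizes
  const int per_group = channels / groups;

  // Weight with only the per-group diagonal blocks populated; everything else
  // stays zero, mirroring the zeros-off-diagonal structure the patch relies on
  // after the added setZero() calls.
  Eigen::MatrixXf weight = Eigen::MatrixXf::Zero(channels, channels);
  for (int g = 0; g < groups; g++)
    weight.block(g * per_group, g * per_group, per_group, per_group).setRandom();

  Eigen::MatrixXf input = Eigen::MatrixXf::Random(channels, frames);

  // One GEMM over the full (mostly zero) matrix...
  Eigen::MatrixXf full = weight * input;

  // ...versus an explicit per-group loop like the code the patch removes.
  Eigen::MatrixXf grouped = Eigen::MatrixXf::Zero(channels, frames);
  for (int g = 0; g < groups; g++)
    grouped.middleRows(g * per_group, per_group).noalias() =
      weight.block(g * per_group, g * per_group, per_group, per_group)
      * input.middleRows(g * per_group, per_group);

  std::cout << "max abs diff: " << (full - grouped).cwiseAbs().maxCoeff() << "\n";
  return 0;
}

The reported difference should be 0: the off-group weights are exact zeros and contribute nothing to each row's dot product, leaving only the GEMM-call-overhead trade-off discussed in the comment.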