Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 28 additions & 85 deletions NAM/conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
this->_num_groups = groups;
this->_weight.resize(kernel_size);
for (size_t i = 0; i < this->_weight.size(); i++)
{
this->_weight[i].resize(out_channels,
in_channels); // y = Ax, input array (C,L)
this->_weight[i].setZero();
}
if (do_bias)
{
this->_bias.resize(out_channels);
this->_bias.setZero();
}
else
this->_bias.resize(0);
this->_dilation = _dilation;
Expand Down Expand Up @@ -104,54 +110,22 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
// Zero output before processing
_output.leftCols(num_frames).setZero();

const int numGroups = this->_num_groups;
const long in_channels = get_in_channels();
const long out_channels = get_out_channels();
const long in_per_group = in_channels / numGroups;
const long out_per_group = out_channels / numGroups;

// Process from ring buffer with dilation lookback
// After Write(), data is at positions [_write_pos, _write_pos+num_frames-1]
// For kernel tap k with offset, we need to read from _write_pos + offset
// The offset is negative (looking back), so _write_pos + offset reads from earlier positions
// The original process_() reads: input.middleCols(i_start + offset, ncols)
// where i_start is the current position and offset is negative for lookback

if (numGroups == 1)
//
// Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
// so we can use a single GEMM for all cases. A more advanced implementation could store
// compact per-group weight matrices and loop over groups, but at typical model sizes
// (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
// and the single sparse GEMM approach is faster.
for (size_t k = 0; k < this->_weight.size(); k++)
{
// Standard convolution (no grouping)
for (size_t k = 0; k < this->_weight.size(); k++)
{
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
const long lookback = -offset;
auto input_block = _input_buffer.Read(num_frames, lookback);
_output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
}
}
else
{
// Grouped convolution: process each group separately
for (int g = 0; g < numGroups; g++)
{
for (size_t k = 0; k < this->_weight.size(); k++)
{
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
const long lookback = -offset;
auto input_block = _input_buffer.Read(num_frames, lookback);

// Extract input slice for this group
auto input_group = input_block.middleRows(g * in_per_group, in_per_group);

// Extract weight slice for this group
auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);

// Extract output slice for this group
auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);

// Perform grouped convolution: output_group += weight_group * input_group
output_group.noalias() += weight_group * input_group;
}
}
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
const long lookback = -offset;
auto input_block = _input_buffer.Read(num_frames, lookback);
_output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
}

// Add bias if present
Expand All @@ -167,49 +141,18 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols,
const long j_start) const
{
const int numGroups = this->_num_groups;
const long in_channels = get_in_channels();
const long out_channels = get_out_channels();
const long in_per_group = in_channels / numGroups;
const long out_per_group = out_channels / numGroups;

if (numGroups == 1)
{
// Standard convolution (no grouping)
for (size_t k = 0; k < this->_weight.size(); k++)
{
const long offset = this->_dilation * (k + 1 - this->_weight.size());
if (k == 0)
output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
else
output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
}
}
else
// Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
// so we can use a single GEMM for all cases. A more advanced implementation could store
// compact per-group weight matrices and loop over groups, but at typical model sizes
// (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
// and the single sparse GEMM approach is faster.
for (size_t k = 0; k < this->_weight.size(); k++)
{
// Grouped convolution: process each group separately
for (int g = 0; g < numGroups; g++)
{
for (size_t k = 0; k < this->_weight.size(); k++)
{
const long offset = this->_dilation * (k + 1 - this->_weight.size());

// Extract input slice for this group
auto input_group = input.middleCols(i_start + offset, ncols).middleRows(g * in_per_group, in_per_group);

// Extract weight slice for this group
auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);

// Extract output slice for this group
auto output_group = output.middleCols(j_start, ncols).middleRows(g * out_per_group, out_per_group);

// Perform grouped convolution
if (k == 0)
output_group.noalias() = weight_group * input_group;
else
output_group.noalias() += weight_group * input_group;
}
}
const long offset = this->_dilation * (k + 1 - this->_weight.size());
if (k == 0)
output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
else
output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
}
if (this->_bias.size() > 0)
{
Expand Down
81 changes: 10 additions & 71 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,13 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool

this->_num_groups = groups;
this->_weight.resize(out_channels, in_channels);
this->_weight.setZero();
this->_do_bias = _bias;
if (_bias)
{
this->_bias.resize(out_channels);
this->_bias.setZero();
}
}


Expand Down Expand Up @@ -374,45 +378,11 @@ void nam::Conv1x1::set_weights_(std::vector<float>::iterator& weights)

// Apply the 1x1 convolution to the first `num_frames` columns of `input`.
//
// input:      matrix whose leftmost `num_frames` columns are read.
// num_frames: number of columns (frames) to process.
// returns:    a new matrix equal to `_weight * input.leftCols(num_frames)`,
//             with `_bias` added to every column when `_do_bias` is set.
Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int num_frames) const
{
  // Single GEMM for all cases - the block-diagonal zero structure of the weight
  // matrix handles grouping, so no per-group loop is needed.
  // NOTE(review): for groups > 1 this relies on the off-group blocks of _weight
  // being zero (the constructor zero-fills _weight); confirm set_weights_ only
  // writes the per-group blocks.
  Eigen::MatrixXf result = this->_weight * input.leftCols(num_frames);

  // Bias is applied exactly once, broadcast across all output columns.
  if (this->_do_bias)
    result.colwise() += this->_bias;

  return result;
}
Expand All @@ -421,40 +391,9 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
{
assert(num_frames <= _output.cols());

const int numGroups = this->_num_groups;
const long in_channels = get_in_channels();
const long out_channels = get_out_channels();
const long in_per_group = in_channels / numGroups;
const long out_per_group = out_channels / numGroups;

if (numGroups == 1)
{
// Standard convolution (no grouping)
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
}
else
{
// Grouped convolution: process each group separately
_output.leftCols(num_frames).setZero();
for (int g = 0; g < numGroups; g++)
{
// Extract input slice for this group
auto input_group = input.leftCols(num_frames).middleRows(g * in_per_group, in_per_group);

// Extract weight slice for this group
auto weight_group = this->_weight.block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);

// Extract output slice for this group
auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);
// Single GEMM for all cases - block-diagonal zero structure handles grouping
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);

// Perform grouped convolution: output_group = weight_group * input_group
output_group.noalias() = weight_group * input_group;
}
}

// Add bias if present
if (this->_do_bias)
{
_output.leftCols(num_frames).colwise() += this->_bias;
}
}