Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 65 additions & 26 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
_input_mixin.SetMaxBufferSize(maxBufferSize);
const long z_channels = this->_conv.get_out_channels(); // This is 2*bottleneck when gated, bottleneck when not
_z.resize(z_channels, maxBufferSize);
_1x1.SetMaxBufferSize(maxBufferSize);
if (this->_layer1x1)
{
this->_layer1x1->SetMaxBufferSize(maxBufferSize);
}
// Pre-allocate output buffers
const long channels = this->get_channels();
this->_output_next_layer.resize(channels, maxBufferSize);
Expand Down Expand Up @@ -47,8 +50,8 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
this->_activation_pre_film->SetMaxBufferSize(maxBufferSize);
if (this->_activation_post_film)
this->_activation_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_1x1_post_film)
this->_1x1_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_layer1x1_post_film)
this->_layer1x1_post_film->SetMaxBufferSize(maxBufferSize);
if (this->_head1x1_post_film)
this->_head1x1_post_film->SetMaxBufferSize(maxBufferSize);
}
Expand All @@ -57,7 +60,10 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
{
this->_conv.set_weights_(weights);
this->_input_mixin.set_weights_(weights);
this->_1x1.set_weights_(weights);
if (this->_layer1x1)
{
this->_layer1x1->set_weights_(weights);
}
if (this->_head1x1)
{
this->_head1x1->set_weights_(weights);
Expand All @@ -75,8 +81,8 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
this->_activation_pre_film->set_weights_(weights);
if (this->_activation_post_film)
this->_activation_post_film->set_weights_(weights);
if (this->_1x1_post_film)
this->_1x1_post_film->set_weights_(weights);
if (this->_layer1x1_post_film)
this->_layer1x1_post_film->set_weights_(weights);
if (this->_head1x1_post_film)
this->_head1x1_post_film->set_weights_(weights);
}
Expand Down Expand Up @@ -137,7 +143,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
{
this->_activation_post_film->Process_(this->_z, condition, num_frames);
}
_1x1.process_(_z, num_frames);
if (this->_layer1x1)
{
this->_layer1x1->process_(this->_z, num_frames);
}
}
else if (this->_gating_mode == GatingMode::GATED)
{
Expand All @@ -153,7 +162,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
this->_activation_post_film->GetOutput().leftCols(num_frames);
}
_1x1.process_(this->_z.topRows(bottleneck), num_frames);
if (this->_layer1x1)
{
this->_layer1x1->process_(this->_z.topRows(bottleneck), num_frames);
}
}
else if (this->_gating_mode == GatingMode::BLENDED)
{
Expand All @@ -169,11 +181,14 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
this->_activation_post_film->GetOutput().leftCols(num_frames);
}
_1x1.process_(this->_z.topRows(bottleneck), num_frames);
if (this->_1x1_post_film)
if (this->_layer1x1)
{
Eigen::MatrixXf& _1x1_output = this->_1x1.GetOutput();
this->_1x1_post_film->Process_(_1x1_output, condition, num_frames);
this->_layer1x1->process_(this->_z.topRows(bottleneck), num_frames);
if (this->_layer1x1_post_film)
{
Eigen::MatrixXf& layer1x1_output = this->_layer1x1->GetOutput();
this->_layer1x1_post_film->Process_(layer1x1_output, condition, num_frames);
}
}
}

Expand All @@ -187,7 +202,6 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
{
this->_head1x1->process_(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
}
this->_head1x1->process(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
if (this->_head1x1_post_film)
{
Eigen::MatrixXf& head1x1_output = this->_head1x1->GetOutput();
Expand All @@ -205,9 +219,17 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
}

// Store output to next layer (residual connection: input + _1x1 output)
this->_output_next_layer.leftCols(num_frames).noalias() =
input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames);
// Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
if (this->_layer1x1)
{
this->_output_next_layer.leftCols(num_frames).noalias() =
input.leftCols(num_frames) + this->_layer1x1->GetOutput().leftCols(num_frames);
}
else
{
// If layer1x1 is inactive, residual connection is just the input (identity)
this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames);
}
}

// LayerArray =================================================================
Expand All @@ -224,10 +246,10 @@ nam::wavenet::_LayerArray::_LayerArray(const LayerArrayParams& params)
LayerParams layer_params(
params.condition_size, params.channels, params.bottleneck, params.kernel_size, params.dilations[i],
params.activation_configs[i], params.gating_modes[i], params.groups_input, params.groups_input_mixin,
params.groups_1x1, params.head1x1_params, params.secondary_activation_configs[i], params.conv_pre_film_params,
params.conv_post_film_params, params.input_mixin_pre_film_params, params.input_mixin_post_film_params,
params.activation_pre_film_params, params.activation_post_film_params, params._1x1_post_film_params,
params.head1x1_post_film_params);
params.layer1x1_params, params.head1x1_params, params.secondary_activation_configs[i],
params.conv_pre_film_params, params.conv_post_film_params, params.input_mixin_pre_film_params,
params.input_mixin_post_film_params, params.activation_pre_film_params, params.activation_post_film_params,
params._layer1x1_post_film_params, params.head1x1_post_film_params);
this->_layers.push_back(_Layer(layer_params));
}
}
Expand Down Expand Up @@ -570,11 +592,21 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st

const int groups = layer_config.value("groups_input", 1); // defaults to 1
const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1
const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1

const int channels = layer_config["channels"];
const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present

// Parse layer1x1 parameters
bool layer1x1_active = true; // default to active if not present
int layer1x1_groups = 1;
if (layer_config.find("layer1x1") != layer_config.end())
{
const auto& layer1x1_config = layer_config["layer1x1"];
layer1x1_active = layer1x1_config["active"]; // default to active
layer1x1_groups = layer1x1_config["groups"];
}
nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups);

const int input_size = layer_config["input_size"];
const int condition_size = layer_config["condition_size"];
const int head_size = layer_config["head_size"];
Expand Down Expand Up @@ -742,9 +774,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
bool head1x1_active = false;
int head1x1_out_channels = channels;
int head1x1_groups = 1;
if (layer_config.find("head_1x1") != layer_config.end())
if (layer_config.find("head1x1") != layer_config.end())
{
const auto& head1x1_config = layer_config["head_1x1"];
const auto& head1x1_config = layer_config["head1x1"];
head1x1_active = head1x1_config["active"];
head1x1_out_channels = head1x1_config["out_channels"];
head1x1_groups = head1x1_config["groups"];
Expand All @@ -771,15 +803,22 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params("input_mixin_post_film");
nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params("activation_pre_film");
nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params("activation_post_film");
nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params("1x1_post_film");
nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film");
nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film");

// Validation: if layer1x1_post_film is active, layer1x1 must also be active
if (_layer1x1_post_film_params.active && !layer1x1_active)
{
throw std::runtime_error("Layer array " + std::to_string(i)
+ ": layer1x1_post_film cannot be active when layer1x1.active is false");
}

layer_array_params.push_back(nam::wavenet::LayerArrayParams(
input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations,
std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, groups_1x1,
std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params,
head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params,
input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params));
}
const bool with_head = !config["head"].is_null();
const float head_scale = config["head_scale"];
Expand Down
Loading