Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions NAM/film.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ class FiLM
/// \param condition_dim Size of the conditioning input
/// \param input_dim Size of the input to be modulated
/// \param shift Whether to apply both scale and shift (true) or only scale (false)
FiLM(const int condition_dim, const int input_dim, const bool shift)
: _cond_to_scale_shift(condition_dim, (shift ? 2 : 1) * input_dim, /*bias=*/true)
/// \param groups Number of groups for grouped convolution in the condition-to-scale-shift submodule (default: 1)
FiLM(const int condition_dim, const int input_dim, const bool shift, const int groups = 1)
: _cond_to_scale_shift(condition_dim, (shift ? 2 : 1) * input_dim, /*bias=*/true, groups)
, _do_shift(shift)
{
}
Expand Down
3 changes: 2 additions & 1 deletion NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,8 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
const nlohmann::json& film_config = layer_config[key];
bool active = film_config.value("active", true);
bool shift = film_config.value("shift", true);
return nam::wavenet::_FiLMParams(active, shift);
int groups = film_config.value("groups", 1);
return nam::wavenet::_FiLMParams(active, shift, groups);
};

// Parse FiLM parameters
Expand Down
36 changes: 22 additions & 14 deletions NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,16 @@ struct _FiLMParams
/// \brief Constructor
/// \param active_ Whether FiLM is active at this location
/// \param shift_ Whether to apply both scale and shift (true) or only scale (false)
_FiLMParams(bool active_, bool shift_)
/// \param groups_ Number of groups for grouped convolution in the condition-to-scale-shift submodule (default: 1)
_FiLMParams(bool active_, bool shift_, int groups_ = 1)
: active(active_)
, shift(shift_)
, groups(groups_)
{
}
const bool active; ///< Whether FiLM is active
const bool shift; ///< Whether to apply shift in addition to scale
const int groups; ///< Number of groups for grouped convolution in the condition-to-scale-shift submodule
};

/// \brief Parameters for constructing a single Layer
Expand Down Expand Up @@ -215,48 +218,53 @@ class _Layer
// Initialize FiLM objects
if (params.conv_pre_film_params.active)
{
_conv_pre_film =
std::make_unique<FiLM>(params.condition_size, params.channels, params.conv_pre_film_params.shift);
_conv_pre_film = std::make_unique<FiLM>(
params.condition_size, params.channels, params.conv_pre_film_params.shift, params.conv_pre_film_params.groups);
}
if (params.conv_post_film_params.active)
{
const int conv_out_channels =
(params.gating_mode != GatingMode::NONE) ? 2 * params.bottleneck : params.bottleneck;
_conv_post_film =
std::make_unique<FiLM>(params.condition_size, conv_out_channels, params.conv_post_film_params.shift);
_conv_post_film = std::make_unique<FiLM>(params.condition_size, conv_out_channels,
params.conv_post_film_params.shift, params.conv_post_film_params.groups);
}
if (params.input_mixin_pre_film_params.active)
{
_input_mixin_pre_film =
std::make_unique<FiLM>(params.condition_size, params.condition_size, params.input_mixin_pre_film_params.shift);
std::make_unique<FiLM>(params.condition_size, params.condition_size, params.input_mixin_pre_film_params.shift,
params.input_mixin_pre_film_params.groups);
}
if (params.input_mixin_post_film_params.active)
{
const int input_mixin_out_channels =
(params.gating_mode != GatingMode::NONE) ? 2 * params.bottleneck : params.bottleneck;
_input_mixin_post_film = std::make_unique<FiLM>(
params.condition_size, input_mixin_out_channels, params.input_mixin_post_film_params.shift);
_input_mixin_post_film =
std::make_unique<FiLM>(params.condition_size, input_mixin_out_channels,
params.input_mixin_post_film_params.shift, params.input_mixin_post_film_params.groups);
}
if (params.activation_pre_film_params.active)
{
const int z_channels = (params.gating_mode != GatingMode::NONE) ? 2 * params.bottleneck : params.bottleneck;
_activation_pre_film =
std::make_unique<FiLM>(params.condition_size, z_channels, params.activation_pre_film_params.shift);
std::make_unique<FiLM>(params.condition_size, z_channels, params.activation_pre_film_params.shift,
params.activation_pre_film_params.groups);
}
if (params.activation_post_film_params.active)
{
_activation_post_film =
std::make_unique<FiLM>(params.condition_size, params.bottleneck, params.activation_post_film_params.shift);
std::make_unique<FiLM>(params.condition_size, params.bottleneck, params.activation_post_film_params.shift,
params.activation_post_film_params.groups);
}
if (params._1x1_post_film_params.active)
{
_1x1_post_film =
std::make_unique<FiLM>(params.condition_size, params.channels, params._1x1_post_film_params.shift);
_1x1_post_film = std::make_unique<FiLM>(params.condition_size, params.channels,
params._1x1_post_film_params.shift, params._1x1_post_film_params.groups);
}
if (params.head1x1_post_film_params.active && params.head1x1_params.active)
{
_head1x1_post_film = std::make_unique<FiLM>(
params.condition_size, params.head1x1_params.out_channels, params.head1x1_post_film_params.shift);
_head1x1_post_film =
std::make_unique<FiLM>(params.condition_size, params.head1x1_params.out_channels,
params.head1x1_post_film_params.shift, params.head1x1_post_film_params.groups);
}
};

Expand Down
Loading