Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions NAM/activations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ static nam::activations::ActivationSigmoid _SIGMOID;
static nam::activations::ActivationSwish _SWISH;
static nam::activations::ActivationHardSwish _HARD_SWISH;
static nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH;
static nam::activations::ActivationSoftsign _SOFTSIGN;

bool nam::activations::Activation::using_fast_tanh = false;

Expand All @@ -31,7 +32,8 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activati
{"SiLU", make_singleton_ptr(_SWISH)},
{"Hardswish", make_singleton_ptr(_HARD_SWISH)},
{"LeakyHardtanh", make_singleton_ptr(_LEAKY_HARD_TANH)},
{"PReLU", make_singleton_ptr(_PRELU)}};
{"PReLU", make_singleton_ptr(_PRELU)},
{"Softsign", make_singleton_ptr(_SOFTSIGN)}};

nam::activations::Activation::Ptr tanh_bak = nullptr;
nam::activations::Activation::Ptr sigmoid_bak = nullptr;
Expand Down Expand Up @@ -68,8 +70,8 @@ nam::activations::ActivationConfig nam::activations::ActivationConfig::from_json
{"SiLU", ActivationType::SiLU},
{"Hardswish", ActivationType::Hardswish},
{"LeakyHardtanh", ActivationType::LeakyHardtanh},
{"LeakyHardTanh", ActivationType::LeakyHardtanh} // Support both casings
};
{"LeakyHardTanh", ActivationType::LeakyHardtanh}, // Support both casings
{"Softsign", ActivationType::Softsign}};

// If it's a string, simple lookup
if (j.is_string())
Expand Down Expand Up @@ -156,6 +158,7 @@ nam::activations::Activation::Ptr nam::activations::Activation::get_activation(c
return std::make_shared<ActivationLeakyHardTanh>(config.min_val.value_or(-1.0f), config.max_val.value_or(1.0f),
config.min_slope.value_or(0.01f),
config.max_slope.value_or(0.01f));
case ActivationType::Softsign: return _activations["Softsign"];
default: return nullptr;
}
}
Expand Down
20 changes: 19 additions & 1 deletion NAM/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ enum class ActivationType
Sigmoid,
SiLU, // aka Swish
Hardswish,
LeakyHardtanh
LeakyHardtanh,
Softsign
};

// Strongly-typed activation configuration
Expand Down Expand Up @@ -130,6 +131,11 @@ inline float hardswish(float x)
}
}

// Softsign nonlinearity: x / (1 + |x|). Smoothly squashes any real input
// into the open interval (-1, 1); gentler tails than tanh.
inline float softsign(float x)
{
  const float denominator = 1.0f + fabsf(x);
  return x / denominator;
}

class Activation
{
public:
Expand Down Expand Up @@ -333,6 +339,18 @@ class ActivationHardSwish : public Activation
}
};

// Element-wise softsign activation: each sample x becomes x / (1 + |x|),
// applied in place over the buffer.
class ActivationSoftsign : public Activation
{
public:
  void apply(float* data, long size) override
  {
    for (long i = 0; i < size; i++)
    {
      const float x = data[i];
      data[i] = x / (1.0f + fabsf(x));
    }
  }
};

class FastLUTActivation : public Activation
{
public:
Expand Down
3 changes: 1 addition & 2 deletions example_models/wavenet_a2_max.nam
Original file line number Diff line number Diff line change
Expand Up @@ -1239,8 +1239,7 @@
2
],
"activation": {
"type": "PReLU",
"negative_slope": 0.015
"type": "Softsign"
},
"gating_mode": "none",
"head_bias": true,
Expand Down
7 changes: 7 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ int main()
test_activations::TestLeakyReLU::test_get_by_init();
test_activations::TestLeakyReLU::test_get_by_str();

test_activations::TestSoftsign::test_core_function();
test_activations::TestSoftsign::test_get_by_init();
test_activations::TestSoftsign::test_get_by_str();

test_lut::TestFastLUT::test_sigmoid();
test_lut::TestFastLUT::test_tanh();

Expand All @@ -53,9 +57,12 @@ int main()
test_activations::TestTypedActivationConfig::test_prelu_single_slope_config();
test_activations::TestTypedActivationConfig::test_prelu_multi_slope_config();
test_activations::TestTypedActivationConfig::test_leaky_hardtanh_config();
test_activations::TestTypedActivationConfig::test_softsign_config();
test_activations::TestTypedActivationConfig::test_from_json_string();
test_activations::TestTypedActivationConfig::test_from_json_object();
test_activations::TestTypedActivationConfig::test_from_json_prelu_multi();
test_activations::TestTypedActivationConfig::test_from_json_softsign_string();
test_activations::TestTypedActivationConfig::test_from_json_softsign_object();
test_activations::TestTypedActivationConfig::test_unknown_activation_throws();

test_dsp::test_construct();
Expand Down
96 changes: 92 additions & 4 deletions tools/test/test_activations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,61 @@ class TestLeakyReLU
}
};
};
class TestSoftsign
{
public:
static void test_core_function()
{
auto TestCase = [](float input, float expectedOutput) {
float actualOutput = nam::activations::softsign(input);
assert(fabs(actualOutput - expectedOutput) < 1e-6);
};
// Test cases for softsign: x / (1 + |x|)
TestCase(0.0f, 0.0f); // 0 / (1 + 0) = 0
TestCase(1.0f, 0.5f); // 1 / (1 + 1) = 0.5
TestCase(-1.0f, -0.5f); // -1 / (1 + 1) = -0.5
TestCase(2.0f, 2.0f / 3.0f); // 2 / (1 + 2) = 2/3
TestCase(-2.0f, -2.0f / 3.0f); // -2 / (1 + 2) = -2/3
};

static void test_get_by_init()
{
auto a = nam::activations::ActivationSoftsign();
_test_class(&a);
}

// Get the singleton and test it
static void test_get_by_str()
{
const std::string name = "Softsign";
auto a = nam::activations::Activation::get_activation(name);
_test_class(a.get());
}

private:
// Put the class through its paces
static void _test_class(nam::activations::Activation* a)
{
std::vector<float> inputs, expectedOutputs;

inputs.push_back(0.0f);
expectedOutputs.push_back(0.0f);

inputs.push_back(1.0f);
expectedOutputs.push_back(0.5f); // 1 / (1 + 1) = 0.5

inputs.push_back(-1.0f);
expectedOutputs.push_back(-0.5f); // -1 / (1 + 1) = -0.5

a->apply(inputs.data(), (long)inputs.size());
for (auto itActual = inputs.begin(), itExpected = expectedOutputs.begin(); itActual != inputs.end();
++itActual, ++itExpected)
{
assert(fabs(*itActual - *itExpected) < 1e-6);
}
};
};

class TestPReLU
{
public:
Expand Down Expand Up @@ -214,10 +269,10 @@ class TestTypedActivationConfig
{
// Test that all simple activation types work
std::vector<nam::activations::ActivationType> types = {
nam::activations::ActivationType::Tanh, nam::activations::ActivationType::Hardtanh,
nam::activations::ActivationType::Fasttanh, nam::activations::ActivationType::ReLU,
nam::activations::ActivationType::Sigmoid, nam::activations::ActivationType::SiLU,
nam::activations::ActivationType::Hardswish};
nam::activations::ActivationType::Tanh, nam::activations::ActivationType::Hardtanh,
nam::activations::ActivationType::Fasttanh, nam::activations::ActivationType::ReLU,
nam::activations::ActivationType::Sigmoid, nam::activations::ActivationType::SiLU,
nam::activations::ActivationType::Hardswish, nam::activations::ActivationType::Softsign};

for (auto type : types)
{
Expand Down Expand Up @@ -296,6 +351,23 @@ class TestTypedActivationConfig
assert(act != nullptr);
}

// Softsign takes no parameters: a bare typed config must resolve to a
// working activation.
static void test_softsign_config()
{
  nam::activations::ActivationConfig config;
  config.type = nam::activations::ActivationType::Softsign;

  auto activation = nam::activations::Activation::get_activation(config);
  assert(activation != nullptr);

  // Spot-check the output: softsign(x) = x / (1 + |x|).
  std::vector<float> samples = {-1.0f, 0.0f, 1.0f};
  activation->apply(samples.data(), (long)samples.size());
  assert(fabs(samples[0] + 0.5f) < 1e-6); // softsign(-1) = -0.5
  assert(fabs(samples[1]) < 1e-6);        // softsign(0)  =  0
  assert(fabs(samples[2] - 0.5f) < 1e-6); // softsign(1)  =  0.5
}

static void test_from_json_string()
{
// Test from_json with string input
Expand Down Expand Up @@ -324,6 +396,22 @@ class TestTypedActivationConfig
assert(config.negative_slopes.value().size() == 4);
}

static void test_from_json_softsign_string()
{
// Test from_json with Softsign as string
nlohmann::json j = "Softsign";
auto config = nam::activations::ActivationConfig::from_json(j);
assert(config.type == nam::activations::ActivationType::Softsign);
}

static void test_from_json_softsign_object()
{
// Test from_json with Softsign as object (alpha parameter is ignored)
nlohmann::json j = {{"type", "Softsign"}, {"alpha", 0.5f}};
auto config = nam::activations::ActivationConfig::from_json(j);
assert(config.type == nam::activations::ActivationType::Softsign);
}

static void test_unknown_activation_throws()
{
// Test that unknown activation type throws
Expand Down