diff --git a/NAM/activations.cpp b/NAM/activations.cpp index 4127d5a..a476520 100644 --- a/NAM/activations.cpp +++ b/NAM/activations.cpp @@ -11,6 +11,7 @@ static nam::activations::ActivationSigmoid _SIGMOID; static nam::activations::ActivationSwish _SWISH; static nam::activations::ActivationHardSwish _HARD_SWISH; static nam::activations::ActivationLeakyHardTanh _LEAKY_HARD_TANH; +static nam::activations::ActivationSoftsign _SOFTSIGN; bool nam::activations::Activation::using_fast_tanh = false; @@ -31,7 +32,8 @@ std::unordered_map nam::activati {"SiLU", make_singleton_ptr(_SWISH)}, {"Hardswish", make_singleton_ptr(_HARD_SWISH)}, {"LeakyHardtanh", make_singleton_ptr(_LEAKY_HARD_TANH)}, - {"PReLU", make_singleton_ptr(_PRELU)}}; + {"PReLU", make_singleton_ptr(_PRELU)}, + {"Softsign", make_singleton_ptr(_SOFTSIGN)}}; nam::activations::Activation::Ptr tanh_bak = nullptr; nam::activations::Activation::Ptr sigmoid_bak = nullptr; @@ -68,8 +70,8 @@ nam::activations::ActivationConfig nam::activations::ActivationConfig::from_json {"SiLU", ActivationType::SiLU}, {"Hardswish", ActivationType::Hardswish}, {"LeakyHardtanh", ActivationType::LeakyHardtanh}, - {"LeakyHardTanh", ActivationType::LeakyHardtanh} // Support both casings - }; + {"LeakyHardTanh", ActivationType::LeakyHardtanh}, // Support both casings + {"Softsign", ActivationType::Softsign}}; // If it's a string, simple lookup if (j.is_string()) @@ -156,6 +158,7 @@ nam::activations::Activation::Ptr nam::activations::Activation::get_activation(c return std::make_shared(config.min_val.value_or(-1.0f), config.max_val.value_or(1.0f), config.min_slope.value_or(0.01f), config.max_slope.value_or(0.01f)); + case ActivationType::Softsign: return _activations["Softsign"]; default: return nullptr; } } diff --git a/NAM/activations.h b/NAM/activations.h index d30e4dd..68d5025 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -33,7 +33,8 @@ enum class ActivationType Sigmoid, SiLU, // aka Swish Hardswish, - LeakyHardtanh + LeakyHardtanh, + Softsign }; // Strongly-typed activation configuration @@ -130,6 +131,11 @@ inline float hardswish(float x) } } +inline float softsign(float x) +{ + return x / (1.0f + fabsf(x)); +} + class Activation { public: @@ -333,6 +339,18 @@ class ActivationHardSwish : public Activation } }; +class ActivationSoftsign : public Activation +{ +public: + void apply(float* data, long size) override + { + for (long pos = 0; pos < size; pos++) + { + data[pos] = softsign(data[pos]); + } + } +}; + class FastLUTActivation : public Activation { public: diff --git a/example_models/wavenet_a2_max.nam b/example_models/wavenet_a2_max.nam index 1a48c95..f54f396 100644 --- a/example_models/wavenet_a2_max.nam +++ b/example_models/wavenet_a2_max.nam @@ -1239,8 +1239,7 @@ 2 ], "activation": { - "type": "PReLU", - "negative_slope": 0.015 + "type": "Softsign" }, "gating_mode": "none", "head_bias": true, diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 218881e..56abfec 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -38,6 +38,10 @@ int main() test_activations::TestLeakyReLU::test_get_by_init(); test_activations::TestLeakyReLU::test_get_by_str(); + test_activations::TestSoftsign::test_core_function(); + test_activations::TestSoftsign::test_get_by_init(); + test_activations::TestSoftsign::test_get_by_str(); + test_lut::TestFastLUT::test_sigmoid(); test_lut::TestFastLUT::test_tanh(); @@ -53,9 +57,12 @@ int main() test_activations::TestTypedActivationConfig::test_prelu_single_slope_config(); test_activations::TestTypedActivationConfig::test_prelu_multi_slope_config(); test_activations::TestTypedActivationConfig::test_leaky_hardtanh_config(); + test_activations::TestTypedActivationConfig::test_softsign_config(); test_activations::TestTypedActivationConfig::test_from_json_string(); test_activations::TestTypedActivationConfig::test_from_json_object(); test_activations::TestTypedActivationConfig::test_from_json_prelu_multi(); + test_activations::TestTypedActivationConfig::test_from_json_softsign_string(); + test_activations::TestTypedActivationConfig::test_from_json_softsign_object(); test_activations::TestTypedActivationConfig::test_unknown_activation_throws(); test_dsp::test_construct(); diff --git a/tools/test/test_activations.cpp b/tools/test/test_activations.cpp index abbdd23..a8dd705 100644 --- a/tools/test/test_activations.cpp +++ b/tools/test/test_activations.cpp @@ -120,6 +120,61 @@ class TestLeakyReLU } }; }; +class TestSoftsign +{ +public: + static void test_core_function() + { + auto TestCase = [](float input, float expectedOutput) { + float actualOutput = nam::activations::softsign(input); + assert(fabs(actualOutput - expectedOutput) < 1e-6); + }; + // Test cases for softsign: x / (1 + |x|) + TestCase(0.0f, 0.0f); // 0 / (1 + 0) = 0 + TestCase(1.0f, 0.5f); // 1 / (1 + 1) = 0.5 + TestCase(-1.0f, -0.5f); // -1 / (1 + 1) = -0.5 + TestCase(2.0f, 2.0f / 3.0f); // 2 / (1 + 2) = 2/3 + TestCase(-2.0f, -2.0f / 3.0f); // -2 / (1 + 2) = -2/3 + }; + + static void test_get_by_init() + { + auto a = nam::activations::ActivationSoftsign(); + _test_class(&a); + } + + // Get the singleton and test it + static void test_get_by_str() + { + const std::string name = "Softsign"; + auto a = nam::activations::Activation::get_activation(name); + _test_class(a.get()); + } + +private: + // Put the class through its paces + static void _test_class(nam::activations::Activation* a) + { + std::vector inputs, expectedOutputs; + + inputs.push_back(0.0f); + expectedOutputs.push_back(0.0f); + + inputs.push_back(1.0f); + expectedOutputs.push_back(0.5f); // 1 / (1 + 1) = 0.5 + + inputs.push_back(-1.0f); + expectedOutputs.push_back(-0.5f); // -1 / (1 + 1) = -0.5 + + a->apply(inputs.data(), (long)inputs.size()); + for (auto itActual = inputs.begin(), itExpected = expectedOutputs.begin(); itActual != inputs.end(); + ++itActual, ++itExpected) + { + assert(fabs(*itActual - *itExpected) < 1e-6); + } + }; +}; + class TestPReLU { public: @@ -214,10 +269,10 @@ class TestTypedActivationConfig { // Test that all simple activation types work std::vector types = { - nam::activations::ActivationType::Tanh, nam::activations::ActivationType::Hardtanh, - nam::activations::ActivationType::Fasttanh, nam::activations::ActivationType::ReLU, - nam::activations::ActivationType::Sigmoid, nam::activations::ActivationType::SiLU, - nam::activations::ActivationType::Hardswish}; + nam::activations::ActivationType::Tanh, nam::activations::ActivationType::Hardtanh, + nam::activations::ActivationType::Fasttanh, nam::activations::ActivationType::ReLU, + nam::activations::ActivationType::Sigmoid, nam::activations::ActivationType::SiLU, + nam::activations::ActivationType::Hardswish, nam::activations::ActivationType::Softsign}; for (auto type : types) { @@ -296,6 +351,23 @@ class TestTypedActivationConfig assert(act != nullptr); } + static void test_softsign_config() + { + // Test Softsign configuration + nam::activations::ActivationConfig config; + config.type = nam::activations::ActivationType::Softsign; + + auto act = nam::activations::Activation::get_activation(config); + assert(act != nullptr); + + // Verify the behavior + std::vector data = {-1.0f, 0.0f, 1.0f}; + act->apply(data.data(), (long)data.size()); + assert(fabs(data[0] - (-0.5f)) < 1e-6); // -1 / (1 + 1) = -0.5 + assert(fabs(data[1] - 0.0f) < 1e-6); + assert(fabs(data[2] - 0.5f) < 1e-6); // 1 / (1 + 1) = 0.5 + } + static void test_from_json_string() { // Test from_json with string input @@ -324,6 +396,22 @@ class TestTypedActivationConfig assert(config.negative_slopes.value().size() == 4); } + static void test_from_json_softsign_string() + { + // Test from_json with Softsign as string + nlohmann::json j = "Softsign"; + auto config = nam::activations::ActivationConfig::from_json(j); + assert(config.type == nam::activations::ActivationType::Softsign); + } + + static void test_from_json_softsign_object() + { + // Test from_json with Softsign as object (alpha parameter is ignored) + nlohmann::json j = {{"type", "Softsign"}, {"alpha", 0.5f}}; + auto config = nam::activations::ActivationConfig::from_json(j); + assert(config.type == nam::activations::ActivationType::Softsign); + } + static void test_unknown_activation_throws() { // Test that unknown activation type throws