diff --git a/index.bs b/index.bs index c011e08a..83c8138e 100644 --- a/index.bs +++ b/index.bs @@ -8270,7 +8270,8 @@ partial dictionary MLOpSupportLimits { } function reduceLogSumExp(builder, input, options) { - return builder.log(builder.reduceSum(builder.exp(input), options)); + const maxX = builder.reduceMax(input, {axes: options.axes, keepDimensions: true}); + return builder.add(maxX, builder.log(builder.reduceSum(builder.exp(builder.sub(input, maxX)), options))); } function reduceSumSquare(builder, input, options) { @@ -9486,8 +9487,7 @@ partial dictionary MLOpSupportLimits {
     function softplus(builder, input) {
-      return builder.log(
-        builder.add(builder.exp(input), builder.constant(input.dataType, 1)));
+      return builder.add(builder.max(input, builder.constant(input.dataType, 0)), builder.log(builder.add(builder.constant(input.dataType, 1), builder.exp(builder.neg(builder.abs(input))))));
     }
     
@@ -9778,13 +9778,13 @@ partial dictionary MLOpSupportLimits {
     function tanh(builder, input) {
-      return builder.div(
+      return builder.mul(builder.sign(input), builder.div(
         builder.sub(
-          builder.exp(builder.mul(builder.constant(input.dataType, 2), input)),
-          builder.constant(input.dataType, 1)),
+          builder.constant(input.dataType, 1),
+          builder.exp(builder.mul(builder.constant(input.dataType, -2), builder.abs(input)))),
         builder.add(
-          builder.exp(builder.mul(builder.constant(input.dataType, 2), input)),
-          builder.constant(input.dataType, 1)));
+          builder.constant(input.dataType, 1),
+          builder.exp(builder.mul(builder.constant(input.dataType, -2), builder.abs(input))))));
     }