diff --git a/index.bs b/index.bs index c011e08a..83c8138e 100644 --- a/index.bs +++ b/index.bs @@ -8270,7 +8270,8 @@ partial dictionary MLOpSupportLimits { } function reduceLogSumExp(builder, input, options) { - return builder.log(builder.reduceSum(builder.exp(input), options)); + const maxX = builder.reduceMax(input, {axes: options.axes, keepDimensions: true}); + return builder.add(maxX, builder.log(builder.reduceSum(builder.exp(builder.sub(input, maxX)), options))); } function reduceSumSquare(builder, input, options) { @@ -9486,8 +9487,7 @@ partial dictionary MLOpSupportLimits {
function softplus(builder, input) {
- return builder.log(
- builder.add(builder.exp(input), builder.constant(input.dataType, 1)));
+ return builder.add(builder.max(input, builder.constant(input.dataType, 0)), builder.log(builder.add(builder.constant(input.dataType, 1), builder.exp(builder.neg(builder.abs(input))))));
}
@@ -9778,13 +9778,13 @@ partial dictionary MLOpSupportLimits {
function tanh(builder, input) {
- return builder.div(
+ return builder.mul(builder.sign(input), builder.div(
builder.sub(
- builder.exp(builder.mul(builder.constant(input.dataType, 2), input)),
- builder.constant(input.dataType, 1)),
+ builder.constant(input.dataType, 1),
+ builder.exp(builder.mul(builder.constant(input.dataType, -2), builder.abs(input)))),
builder.add(
- builder.exp(builder.mul(builder.constant(input.dataType, 2), input)),
- builder.constant(input.dataType, 1)));
+ builder.constant(input.dataType, 1),
+ builder.exp(builder.mul(builder.constant(input.dataType, -2), builder.abs(input))))));
}