diff --git a/src/Picasso_BatchedLinearAlgebra.hpp b/src/Picasso_BatchedLinearAlgebra.hpp index 74c3a48..d616176 100644 --- a/src/Picasso_BatchedLinearAlgebra.hpp +++ b/src/Picasso_BatchedLinearAlgebra.hpp @@ -4523,6 +4523,34 @@ KOKKOS_INLINE_FUNCTION auto contract( const ExpressionT& t, return res; } +template ::value && + is_matrix::value, + int> = 0> +KOKKOS_INLINE_FUNCTION auto contract( const ExpressionT& t, + const ExpressionM& m ) +{ + static_assert( ExpressionT::extent_1 == ExpressionM::extent_0, + "Inner extents must match" ); + static_assert( ExpressionT::extent_2 == ExpressionM::extent_1, + "Inner extents must match " ); + + typename ExpressionT::eval_type t_eval = t; + typename ExpressionM::eval_type m_eval = m; + Vector res = + static_cast( 0 ); + + for ( int i = 0; i < ExpressionT::extent_0; ++i ) +#if defined( KOKKOS_ENABLE_PRAGMA_UNROLL ) +#pragma unroll +#endif + for ( int j = 0; j < ExpressionM::extent_0; ++j ) + for ( int k = 0; k < ExpressionM::extent_1; ++k ) + res( i ) += t_eval( i, j, k ) * m_eval( j, k ); + + return res; +} + template ::value && is_matrix::value,