diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index d5995939b0270..b852e3f1ff3f5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -29,10 +29,6 @@ namespace sycl { inline namespace _V1 { namespace ext { -namespace intel::experimental::matrix::layout { -constexpr sycl::ext::oneapi::experimental::matrix::layout packed = - static_cast(2); -} namespace oneapi { namespace experimental { namespace matrix { @@ -48,8 +44,7 @@ template struct spv_matrix_layout_traits { SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor) SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor) -SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed, - __spv::MatrixLayout::Packed) +SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed) SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic) template struct spv_matrix_use_traits { @@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait< using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32; using storage_element_type = float; }; -} // namespace detail -} // namespace oneapi - -namespace intel::experimental::matrix { using namespace sycl::ext::oneapi::experimental::matrix; // Begin wi_element definition @@ -121,12 +112,12 @@ class wi_element { std::size_t i) : M(Mat), idx(i) {} - inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { + inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { #if defined(__SYCL_DEVICE_ONLY__) __ocl_vec_t coord = __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx); - const uint32_t row = coord[0]; - const uint32_t col = coord[1]; + const size_t row = coord[0]; + const size_t col = coord[1]; return std::make_tuple(row, col); #else throw runtime_error("joint matrix is not supported on host device.", @@ -196,7 +187,7 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ M.spvm = __spirv_VectorInsertDynamic( \ M.spvm, \ static_cast( \ @@ -211,7 +202,7 @@ class wi_element { } #else // __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ PI_ERROR_INVALID_DEVICE); \ @@ -315,7 +306,7 @@ class wi_element = true> inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group, - sycl::ext::oneapi::experimental::matrix::joint_matrix< + const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout> &src, multi_ptr dst, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) @@ -526,6 +520,43 @@ joint_matrix_store(Group, PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) } + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( + Group sg, + sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, + F &&lambda) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { + lambda(jm.cuda_impl.wi_marray[i]); + } +#else // NVPTX + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); + for (int i = 0; i < wi_data_c.length(); i++) { + storage_element_type element = wi_data_c[i]; + auto [row, col] = wi_data_c[i].get_coord(); + lambda(element, row, col); + wi_data_c[i] = element; + } +#endif +#else + std::ignore = sg; + std::ignore = jm; + std::ignore = lambda; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif +} + } // namespace intel::experimental::matrix } // namespace ext diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index f51e146fd9a0c..8a9dbc12df2ec 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -16,7 +16,12 @@ namespace matrix { enum class use { a, b, accumulator }; -enum class layout { row_major = 0, col_major = 1, dynamic = 3 }; +enum class layout { + row_major = 0, + col_major = 1, + ext_intel_packed = 2, + dynamic = 3 +}; namespace precision { class tf32 { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index fe62089549b34..327e1e326f108 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,8 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_cuda + mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< @@ -61,19 +62,8 @@ struct joint_matrix { } #ifdef __SYCL_DEVICE_ONLY__ #if defined(__SPIR__) - // Generate a non-trivial assignment operator and copy c'tor that prevents - // memcpy from being generated. - // TODO: to remove, when either IGC can handle alloca JointMatrix or - // combination of InstCombine + SROA + mem2reg can remove it - joint_matrix(const joint_matrix &other) { - spvm = other.spvm; - return *this; - } - - joint_matrix &operator=(const joint_matrix &rhs) { - spvm = rhs.spvm; - return *this; - } + joint_matrix(const joint_matrix &other) = delete; + joint_matrix &operator=(const joint_matrix &rhs) = delete; #endif // defined(__SPIR__) #endif }; @@ -97,10 +87,6 @@ class wi_data { size_t length() { #if defined(__NVPTX__) return jm.cuda_impl.wi_marray.size(); -#else - throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", - PI_ERROR_INVALID_DEVICE); #endif }; @@ -109,9 +95,6 @@ class wi_data { return (jm.cuda_impl.wi_marray[i]); #else std::ignore = i; - throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", - PI_ERROR_INVALID_DEVICE); #endif }; }; @@ -139,9 +122,8 @@ template &jm, using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< T>::storage_element_type; - auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm); + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); for (int i = 0; i < wi_data_c.length(); i++) { storage_element_type element = wi_data_c[i]; lambda(element); @@ -260,7 +242,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case layout::ext_intel_packed: res.spvm = __spirv_JointMatrixLoadINTEL< DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, @@ -322,8 +304,9 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Group, - joint_matrix &src, + const joint_matrix + &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { #if defined(__SYCL_DEVICE_ONLY__) @@ -355,7 +338,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case layout::ext_intel_packed: __spirv_JointMatrixStoreINTEL< DecorT, T, NumRows, NumCols, spv_matrix_use_traits::value, @@ -375,51 +358,77 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( #endif // defined(__SYCL_DEVICE_ONLY__) } -template -inline __SYCL_ALWAYS_INLINE - joint_matrix - joint_matrix_mad( - Group, joint_matrix &A, - joint_matrix &B, - joint_matrix - &C) { +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( + Group, + joint_matrix &D, + const joint_matrix &A, + const joint_matrix &B, + const joint_matrix + &C) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) if constexpr (std::is_same::value) { - joint_matrix - D; sycl::ext::oneapi::detail::joint_matrix_mad_cuda( D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); - return D; } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); } #else - joint_matrix res; if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && std::is_unsigned::value) - res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_signed::value && std::is_unsigned::value) - res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && std::is_signed::value) - res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); else - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); - return res; + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #endif // defined(__NVPTX__) #else std::ignore = A; std::ignore = B; std::ignore = C; + std::ignore = D; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +void joint_matrix_copy( + Group sg, joint_matrix &src, + joint_matrix &dst) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { + dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + } +#else + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T2>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src); + auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst); + for (int i = 0; i < wi_data_c.length(); i++) { + wi_data_dst[i] = static_cast(wi_data_c[i]); + } +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = dst; + std::ignore = src; throw runtime_error("joint matrix is not supported on host device.", PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) diff --git a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 6559f4c93248d..0000000000000 --- a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix-xmx8 - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// this code calculates the sum of rows into a global array of number of rows -// elements. First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 8 -constexpr size_t TN = 8; - -#include "../element_wise_irreg_sum_rows_impl.hpp" diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 8b7ee3af2b9c5..378c46c4b84d5 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,8 +55,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; joint_matrix sub_c; @@ -65,33 +64,21 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] += 1; - } + joint_matrix_apply(sg, sub_a, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_b, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] += 1; - } + joint_matrix_apply(sg, sub_b, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_c, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 1; - } + joint_matrix_apply(sg, sub_c, [](T1 &x) { x += 1; }); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 540d75c245815..42e1afb4d69f1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,11 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x + static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -76,11 +73,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x - static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +105,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x * static_cast(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +137,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x / static_cast(2.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -181,30 +169,25 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > static_cast(2.0) || - wi_slice_a[i] >= static_cast(2.0) || - wi_slice_a[i] < static_cast(2.0) || - wi_slice_a[i] <= static_cast(2.0)) { - T val = (wi_slice_a[i] != static_cast(2.0)) - ? wi_slice_a[i] - : static_cast(2.0); + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > static_cast(2.0) || x >= static_cast(2.0) || + x < static_cast(2.0) || x <= static_cast(2.0)) { + T val = + (x != static_cast(2.0)) ? x : static_cast(2.0); val--; val++; - if (wi_slice_a[i] == static_cast(2.0)) { + if (x == static_cast(2.0)) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 8e15488e151a0..b11d3093bf08d 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,12 +63,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } - + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_mat, accessMat.template get_multi_ptr() + @@ -104,11 +99,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); joint_matrix_store( sg, sub_mat, diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 803ebe0addb3a..4a43d39738657 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,11 +40,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -75,11 +71,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -110,11 +102,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * 3; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -145,11 +133,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -180,26 +164,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || - wi_slice_a[i] < 2 || wi_slice_a[i] <= 2) { - T val = (wi_slice_a[i] != 2) ? wi_slice_a[i] : 2; + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? x : 2; val--; val++; - if (wi_slice_a[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index ce89a04b4168c..e3d21a36bd6e1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -37,16 +37,12 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] + 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -74,16 +70,12 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] - 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -111,16 +103,12 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] * 3; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -148,16 +136,12 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] / 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -185,31 +169,28 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - if (wi_slice_b[i]) { - if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || - wi_slice_b[i] < 2 || wi_slice_b[i] <= 2) { - T val = (wi_slice_b[i] != 2) ? wi_slice_b[i] : 2; + joint_matrix_apply(sg, sub_b, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? x : 2; val--; val++; - if (wi_slice_b[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_b[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 27eacf89c748a..53b83a7f5d389 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -41,11 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x + round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, @@ -77,11 +74,9 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x - round_to_tf32(2); }); + ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +106,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x * round_to_tf32(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +138,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x / round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -180,27 +169,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || - wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) { - Ts val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i] : 2.0; - val = val - static_cast(1); - val = val + static_cast(1); - if (wi_slice_a[i] == 2.0) { - val = val - static_cast(2); - val = val * static_cast(3); - val = val / static_cast(2); - + joint_matrix_apply(sg, sub_a, [&](float &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + float val = (x != 2) ? x : 2; + val--; + val++; + if (x == 2) { + val -= 2; + val *= 3; + val /= 2; } else { - val = val + static_cast(2); + val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index c49f9b57e2f32..6e1b6410547ad 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,30 +49,26 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( - SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + joint_matrix_apply(sg, sub_a, [=](T &x) { x += val2; }); + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 1cb48f1bc4f72..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// This code calculates the sum of rows into a global array of number of rows -// elements. First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 16 -constexpr size_t TN = 16; - -#include "element_wise_irreg_sum_rows_impl.hpp" diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp deleted file mode 100644 index cfce95cba269f..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ /dev/null @@ -1,107 +0,0 @@ -#define TK 32 - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void sum_rows_ref(host_accessor B, - host_accessor sum_rows) { - int sum_rows_ref[M] = {0}; - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - sum_rows_ref[i] += B[i][j]; - } - auto diff = sum_rows[i] - sum_rows_ref[i]; - assert(std::fabs(static_cast(diff)) <= - std::numeric_limits::epsilon()); - } -} - -template -void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { - buffer bufB(B.get_data(), range<2>(M, N)); - // size of vector is known because SG size of set by the user in this case - int sum_rows[M] = {0}; - buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows - q.submit([&](handler &cgh) { - auto accB = bufB.get_access(cgh); - - auto v = sum_rows_v.get_access(cgh); - - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix - sub_b; - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N); - // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b - // (tK/4) - int32_t sum_local_rows[M] = {0}; // 8 local rows, M total - // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - - // each WI calculates local sum of rows - for (int row = 0; row < TK / 4; row++) { // there are 8 rows - for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row - // i*SG_SIZE index is found based on the round robin - // distribution we are using in the implementation - sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; - } - sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( - sg, sum_local_rows[row + global_idx * (TK / 4)], - sycl::plus<>()); - - // only Groups leader perform the global reduction - if (global_idy % SG_SZ == 0) { - atomic_fetch_add(v[row + global_idx * (TK / 4)], - sum_local_rows[row + global_idx * (TK / 4)]); - } - } - }); // parallel for - }).wait(); - sum_rows_ref(bufB.get_host_access(read_only), - sum_rows_v.get_host_access(read_only)); -} - -static constexpr size_t MATRIX_K = TK / 4 * 2; -static constexpr size_t MATRIX_N = TN * 4 * 2; -int8_t B[MATRIX_K][MATRIX_N]; - -int main() { - big_matrix MB((int8_t *)&B); - - size_t NDRangeK = MATRIX_K / (TK / 4); - size_t NDRangeN = (MATRIX_N / 4) / TN; - queue q; - nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); - - for (int i = 0; i < MATRIX_K; i++) { - for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; - } - } - - matrix_sum_rows(q, MB, r); - - return 0; -} diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1206b556339a9..1dd9779aa0b56 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,13 +72,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } + joint_matrix_apply(sg, sub_c, [](int32_t &x) { x = x * 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 8e2865de207b4..cc8722467d262 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load( @@ -101,13 +101,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 5.0; + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } + joint_matrix_apply(sg, sub_c, [](float &x) { x += 5.0; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp index dea7601437742..f9d19e914e639 100644 --- a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp @@ -50,15 +50,11 @@ void matrix_sum_rows(big_matrix &C, float *sum_rows) { N, layout::row_major); float sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - - for (int i = 0; i < data.length(); ++i) { - auto dataItem = data[i]; - auto [row, col] = dataItem.get_coord(); - sum_local_rows[row + global_idx * TM] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_c, [&](float &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; i++) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp index ec21cfa036807..619f97969b29c 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp @@ -96,16 +96,11 @@ void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { K); int32_t sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - - // each WI calculates local sum of rows - for (int i = 0; i < data.length(); ++i) { - auto data_item = data[i]; - auto [row, col] = data_item.get_coord(); - sum_local_rows[row + global_idx * TM] += data_item; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_a, [&](int8_t &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; ++i) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index dfbad2d19f946..85f080db149eb 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -113,8 +113,7 @@ void matrix_sum_cols(queue q, big_matrix &B, sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix sub_b; joint_matrix_load(sg, sub_b, @@ -124,22 +123,16 @@ void matrix_sum_cols(queue q, big_matrix &B, N * VF); int32_t sum_local_cols[N] = {0}; - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - // the coordinates returned are in the logical range [K,N] - // If users want to retrieve the VNNIed coordinates, they can be - // obtained using - // colVNNI = col/VF - // rowVNNI = row*VF - auto [row, col] = dataItem.get_coord(); - size_t global_index = col + global_idy / SG_SZ * TN; - sum_local_cols[global_index] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, [&](int8_t &x, size_t row, size_t col) { + // the coordinates returned are in the logical range [K,N] + // If users want to retrieve the VNNIed coordinates, they can be + // obtained using + // colVNNI = col/VF + // rowVNNI = row*VF + size_t global_index = col + global_idy / SG_SZ * TN; + sum_local_cols[global_index] += x; + }); for (int i = 0; i < N; i++) { sum_local_cols[i] = @@ -178,4 +171,4 @@ int main() { matrix_sum_cols(q, MB, MBvnni, r); std::cout << "Passed\n"; return 0; -} +} \ No newline at end of file diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 469837cd26d41..00149e0b55ce4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,8 +56,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; joint_matrix sub_c; @@ -78,7 +77,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index d6aecdd299bef..8a6c6672e0a5b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -53,34 +53,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>(r, [ - accC, lambda - ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>( + r, [accC, lambda]( + nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -113,8 +113,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [ accC, - Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -156,7 +156,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, @@ -164,8 +164,7 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }) - .wait(); + }).wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index f699891176ea7..59836622ac1d7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -168,35 +168,27 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { ; joint_matrix + layout::ext_intel_packed> tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), } #endif ; @@ -248,8 +240,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { for (unsigned int n = 0; n < NCACHE1 / tN; n++) { #endif - tC[m][n] = - joint_matrix_mad(sg, tA[m][k1], tB[n][k1], tC[m][n]); + joint_matrix_mad(sg, tC[m][n], tA[m][k1], tB[n][k1], + tC[m][n]); #ifdef MANUAL_UNROLL }); // n }); // m diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index cc0196660744a..40cc2ad58bdc6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -47,7 +47,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d6390d8061dcc..671cf78b660a1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -60,7 +60,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); + joint_matrix_mad(sg, sub_c[i], sub_a[i], sub_b, sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 4847e093127a8..7c07afcb3ecb7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 76ac69da27677..ddf731fba24a3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -46,7 +46,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -66,7 +66,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 4d61a733e5927..119554c9b23ad 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index f75da1824d94b..f24f720715788 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 6972e3854c8e8..68e9d3c145675 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -48,15 +48,7 @@ void matrix_copy(big_matrix &C, big_matrix &A) { accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - // This will be replaced by joint_matrix_copy API - // joint_matrix_copy(sg, sub_c, sub_ac); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_a[i] = (bfloat16)wi_slice_c[i]; - } + joint_matrix_copy(sg, sub_c, sub_a); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 5e451d45d7727..219a3976f4c90 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index f92548d2f7ed8..c7a09229063eb 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -51,7 +51,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -71,7 +71,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6111f503007f5..d2081f01ec167 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index b6fe3f0376ffd..f4f4d682930a4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 1c1c4f97819bf..51ea6745a8174 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -44,7 +44,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; // bounds-checked load where width and height are added @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 112a744db6e3b..760b9961050b0 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -79,9 +79,7 @@ void matrix_multiply(big_matrix &C, sycl::sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a; - myparams2::joint_matrix_b< - sub_group, ext::intel::experimental::matrix::layout::packed> - sub_b; + myparams2::joint_matrix_b sub_b; myparams2::joint_matrix_c sub_c; joint_matrix_load( @@ -101,7 +99,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5cca6572cef21..8135897f893f9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 397fcc9a5aa97..2730f0f6184de 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 4d4ba0ee951e9..607aba535c74f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,15 +76,11 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_apply(sg, sub_b, + [=](float x) { x = round_to_tf32(x); }); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 24f6cce4cc09d..02bad19d0d4f4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -43,7 +43,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, @@ -55,7 +55,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 1d82f8833aba6..47c9d82e18479 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -54,7 +54,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -75,7 +75,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index e400b6694e4a9..c132aeafef9d2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 80b67b14a55ac..309786a38003f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 31f77dc55b16f..16603407d74b1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 7c24179022d55..47ddc0fb42f48 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0e0b4ce903be2..0468f592b6427 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 575039723d56e..858c8625cc6e9 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 8ada375fff395..f47a701fe7bc6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 69bc136e79776..c6a1bda15cdcb 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index e5935b8b3af47..22c8203444ab4 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -30,7 +30,7 @@ int main(void) { layout::row_major> tA; joint_matrix + layout::ext_intel_packed> tB; joint_matrix tC; @@ -49,7 +49,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tC, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index ceb741c04982a..ee6d37654184e 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -157,7 +157,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // TK = 32, TN = 16 joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( @@ -168,8 +168,6 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. @@ -177,19 +175,15 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // Keep track of cols handled in this WI int32_t handled_cols[N] = {-1}; - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - - // Calculation of global index - int sg_idx = (int)global_idy / SG_SZ; - global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; - sum_local_cols[global_index] += wiData[i]; - handled_cols[global_index] = 1; - } - + sycl::ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, + [&](int8_t &x, size_t row, + size_t col) { // Calculation of global index + int sg_idx = (int)global_idy / SG_SZ; + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; + sum_local_cols[global_index] += x; + handled_cols[global_index] = 1; + }); for (int j = 0; j < N; j++) { if (handled_cols[j] == 1) { global_index = j; diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 2e0e309081464..37dc5a1607631 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 3205e4c346ba6..9621f570cf461 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,7 +69,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -93,13 +93,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - auto wi_data_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_data_c.length(); i++) { - wi_data_c[i] *= 2; + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } + joint_matrix_apply(sg, sub_c, [](int32_t &x) { x *= 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index f8dcc26ab1b17..c4ab58c1deaec 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,7 +74,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -94,7 +94,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index d6affb4067003..496af7dabd335 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,12 +87,9 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_apply(sg, sub_b, + [=](float &x) { x = round_to_tf32(x); }); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c,