Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,20 @@ q.submit([&](sycl::handler& cgh) {
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_fill(sg, tC, 0);
for (int k = 0; k < K; k += tK) {
joint_matrix_load(sg, tA, accA + sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB, accB + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
tC = joint_matrix_mad(sg, tA, tB, tC);
joint_matrix_load(sg, tA,
accA.template get_multi_ptr<sycl::access::decorated::no>() +
sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB,
accB.template get_multi_ptr<sycl::access::decorated::no>() +
k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
joint_matrix_mad(sg, tC, tA, tB, tC);
}
auto wi_data_c = ext::intel::experimental::matrix::get_wi_data(sg, tC);
for (int i = 0; i < wi_data_c.length(); i++)
wi_data_c[i] *= alpha;
joint_matrix_apply(sg, tC, [=](int8_t x) {
x *= alpha;
});
joint_matrix_store(sg, tC,
accC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
accC.template get_multi_ptr<sycl::access::decorated::no>()
+ sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
});
});
q.wait();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout.
```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename Ta, typename Tb, typename Tc,
std::size_t M, std::size_t K, std::size_t N, layout LayoutA, layout
LayoutB, typename Td = Tc>
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic>
joint_matrix_mad(Group g,
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
std::size_t M, std::size_t K, std::size_t N,
layout LayoutA, layout LayoutB>
void joint_matrix_mad(Group g,
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic> &D,
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);
Expand All @@ -287,7 +287,7 @@ joint_matrix_mad(Group g,
```
The matrix multiply and add function performs the multiply operation
on the matrices `A` and `B`, accumulates the result with `C` and returns
the result.
the result into the matrix `D`.

Each device supports only certain combinations of types for the `A`,
`B`, and `C` matrices. The application must use the query operations
Expand Down Expand Up @@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE};
int8_t *memA = malloc_shared<int8_t>(M*K, q);
int8_t *memB = malloc_shared<int8_t>(K*N, q);
int32_t *memC = malloc_shared<int32_t>(M*N, q);
auto pA = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memA);
auto pB = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memB);
auto pC = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memC);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
[[sycl::reqd_sub_group_size(SG_SIZE)]] {
const auto global_idx = item.get_global_id(0);
Expand All @@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_fill(sg, tC, 0);
for (int k = 0; k < K; k += tK) {
joint_matrix_load(sg, tA,
multi_ptr<int8_t, sycl::access::address_space::global_space>(memA) +
sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB,
multi_ptr<int8_t, sycl::access::address_space::global_space>(memB) +
k * N + sg_starty/SG_SIZE*tN, N);
tC = joint_matrix_mad(sg, tA, tB, tC);
joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
joint_matrix_mad(sg, tC, tA, tB, tC);
}
joint_matrix_apply(sg, tC, [=](int8_t x) {
x *= alpha;
});
joint_matrix_store(sg, tC,
multi_ptr<int32_t, sycl::access::address_space::global_space>(memC) +
sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
N, layout::row_major);
}).wait();
```

Expand Down