diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c3e5a9eccc3aa..992d54c36062a 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5366,7 +5366,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
     device->pinned_memory.erase(device->pinned_memory.begin() + index);
 }
 
-static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
+static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
     std::lock_guard guard(device->mutex);
     buf = nullptr;
     buf_offset = 0;
@@ -5381,6 +5381,39 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf
     }
 }
 
+static vk_subbuffer ggml_vk_tensor_subbuffer(
+        const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false,
+        vk_buffer buffer = nullptr, size_t offset = 0) {
+
+    if (!buffer) {
+        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
+        buffer = buf_ctx->dev_buffer;
+        offset = vk_tensor_offset(tensor) + tensor->view_offs;
+    }
+    GGML_ASSERT(buffer != nullptr);
+
+    size_t size = ggml_nbytes(tensor);
+
+    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
+    // The shader must support misaligned offsets when indexing into the buffer
+    GGML_ASSERT(allow_misalign || misalign_bytes == 0);
+    offset &= ~misalign_bytes;
+    size += misalign_bytes;
+
+    return vk_subbuffer{buffer, offset, size};
+}
+
+static vk_subbuffer ggml_vk_tensor_subbuffer_uma(
+        const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
+
+    vk_buffer buffer = nullptr;
+    size_t offset = 0;
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
+    }
+    return ggml_vk_tensor_subbuffer(ctx, tensor, allow_misalign, std::move(buffer), offset);
+}
+
 static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
     vk_submission s;
     s.buffer = ggml_vk_create_cmd_buffer(device, p);
@@ -7739,72 +7772,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
-
-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
-        Q_uma = d_Q != nullptr;
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        D_uma = d_D != nullptr;
-        if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
-            M_uma = d_M != nullptr;
-        }
-        if (sinks) {
-            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
-            S_uma = d_S != nullptr;
-        }
-    }
-
-
-    ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-
-    if (!Q_uma) {
-        d_Q = q_buf_ctx->dev_buffer;
-        q_buf_offset = vk_tensor_offset(q) + q->view_offs;
-    }
-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_buf_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_buf_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!D_uma) {
-        d_D = d_buf_ctx->dev_buffer;
-        d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    }
-
-    if (!M_uma) {
-        d_M = d_Q;
-        m_buf_offset = q_buf_offset;
-        if (mask) {
-            ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
-            d_M = m_buf_ctx->dev_buffer;
-            m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
-        }
-    }
-
-    if (!S_uma) {
-        d_S = d_Q;
-        s_buf_offset = q_buf_offset;
-        if (sinks) {
-            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
-            d_S = s_buf_ctx->dev_buffer;
-            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
-        }
-    }
+    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer_uma(ctx, q);
+    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer_uma(ctx, k);
+    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer_uma(ctx, v);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
+    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer_uma(ctx, mask) : q_buf;
+    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer_uma(ctx, sinks) : q_buf;
 
     uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
@@ -7826,15 +7799,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_sync_buffers(ctx, subctx);
         }
 
+        vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                  {
-                                      ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                      ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
-                                  },
+                                  {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
                                   // We only use split_k when group query attention is enabled, which means
                                   // there's no more than one tile of rows (i.e. workgroups_x would have been
                                   // one). We reuse workgroups_x to mean the number of splits, so we need to
@@ -7844,23 +7811,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_sync_buffers(ctx, subctx);
         const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
-                                  {
-                                      ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
-                                      ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
-                                  },
+                                  {split_k_buf, sinks_buf, dst_buf},
                                   pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
         ctx->prealloc_split_k_need_sync = true;
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                  {
-                                      ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
-                                      ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
-                                  },
+                                  {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
                                   pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
@@ -8543,35 +8499,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint64_t ne01 = src0->ne[1];
     const uint64_t ne02 = src0->ne[2];
     const uint64_t ne03 = src0->ne[3];
-    const uint64_t ne0 = ne00 * ne01;
 
     const bool use_src1 = src1 != nullptr;
     const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
     const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
     const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
     const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
-    const uint64_t ne1 = ne10 * ne11;
-    // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;
 
     const bool use_src2 = src2 != nullptr;
-    const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
-    const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
-    const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
-    const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
-    const uint64_t ne2 = ne20 * ne21;
-
     const bool use_src3 = src3 != nullptr;
-    const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
-    const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
-    const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
-    const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
-    const uint64_t ne3 = ne30 * ne31;
-
-    const uint64_t ned0 = dst->ne[0];
-    const uint64_t ned1 = dst->ne[1];
-    const uint64_t ned2 = dst->ne[2];
-    const uint64_t ned3 = dst->ne[3];
-    const uint64_t ned = ned0 * ned1;
 
     init_pushconst_fastdiv(pc);
 
@@ -8593,74 +8529,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
-
-    vk_buffer d_X = nullptr;
-    size_t x_buf_offset = 0;
-    vk_buffer d_Y = nullptr;
-    size_t y_buf_offset = 0;
-    vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;
-    vk_buffer d_W = nullptr;
-    size_t w_buf_offset = 0;
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer_uma(ctx, src0, op_supports_incontiguous);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer_uma(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer_uma(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer_uma(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
 
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool src2_uma = false;
-    bool src3_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
-        src0_uma = d_X != nullptr;
-        if (use_src1) {
-            ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
-            src1_uma = d_Y != nullptr;
-        }
-        if (use_src2) {
-            ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
-        if (use_src3) {
-            ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
-            src3_uma = d_W != nullptr;
-        }
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-
-    GGML_ASSERT(d_D != nullptr);
-    uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    if(!src0_uma) {
-        d_X = src0_buf_ctx->dev_buffer;
-        x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_X != nullptr);
-    }
-    if (use_src1 && !src1_uma) {
-        d_Y = src1_buf_ctx->dev_buffer;
-        y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Y != nullptr);
-    }
-    if (use_src2 && !src2_uma) {
-        d_Z = src2_buf_ctx->dev_buffer;
-        z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-    if (use_src3 && !src3_uma) {
-        d_W = src3_buf_ctx->dev_buffer;
-        w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
-        GGML_ASSERT(d_W != nullptr);
-    }
-    // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
+    // Compute misalignment offset for descriptors and store it in push constants.
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
-    x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
 
     std::array<uint32_t, 3> elements;
 
@@ -8744,9 +8620,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         const uint32_t KH = ne01;
         const uint32_t KW = ne00;
 
-        const uint32_t OD = ned3 / N;
-        const uint32_t OH = ned2;
-        const uint32_t OW = ned1;
+        const uint32_t OD = dst->ne[3] / N;
+        const uint32_t OH = dst->ne[2];
+        const uint32_t OW = dst->ne[1];
 
         const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
         const uint32_t N_OD_OH = N*OD*OH;
@@ -8861,112 +8737,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     }
 
-    uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;
-
-    if (op_supports_incontiguous) {
-        x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
-        y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
-        z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
-        w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
-        d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
-
-        if (x_buf_offset + x_sz >= d_X->size) {
-            x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
-        }
-        if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
-            y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
-        }
-        if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
-            z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
-        }
-        if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
-            w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
-        }
-        if (d_buf_offset + d_sz >= d_D->size) {
-            d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
-        }
-    } else {
-        x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
-        y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
-        z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
-        w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
-        d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
-    }
-
     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
-        size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
+        vk_subbuffer a_buf = src0_buf;
+        if (ctx->do_add_rms_partials) {
+            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
-              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
-              vk_subbuffer{ d_D, d_buf_offset, d_sz },
-              ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
-            }, pc, elements);
+            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
     } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
     } else if (op == GGML_OP_SOFT_MAX) {
         // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
-        // Empty src2 is possible in rope, but the shader needs a buffer
-        vk_subbuffer subbuf_z, subbuf_w;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-        if (use_src3) {
-            subbuf_w = { d_W, w_buf_offset, w_sz };
-        } else {
-            subbuf_w = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
+        // Empty src2 and src3 is possible in rope, but the shader needs a buffer
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
     } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
         if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
             // buffer device address path doesn't use dst buffer
-            d_sz = 1;
+            dst_buf.size = 1;
         }
         // im2col uses only src1 and dst buffers
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
         // count_equal assumes that destination buffer is initialized with zeroes
-        ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
+        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
         ggml_vk_sync_buffers(ctx, subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_OPT_STEP_SGD) {
         // OPT_STEP_SGD works on src0, it does not need dst
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
     } else if (use_src3) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
     } else if (use_src2) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
     } else if (use_src1) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
     }
 }
 
@@ -9208,39 +9022,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
         return;
     }
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
-    for (int i = 0; i < num_srcs; i++) {
-        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
-    }
-
-    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
-    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
-    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
-
-    if (ctx->device->uma) {
-        for (int i = 0; i < num_srcs; i++) {
-            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
-            srcs_uma[i] = d_srcs[i] != nullptr;
-        }
-
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
-        dst_uma = d_D != nullptr;
-    }
-
-    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
+    vk_subbuffer src_buf[7] = {};
     for (int i = 0; i < num_srcs; i++) {
-        src_sizes[i] = ggml_nbytes(dst->src[i]);
-        if (!srcs_uma[i]) {
-            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
-            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
-        }
-    }
-
-    const uint64_t dst_size = ggml_nbytes(dst);
-    if (!dst_uma) {
-        d_D = dst_buf_ctx->dev_buffer;
-        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+        src_buf[i] = ggml_vk_tensor_subbuffer_uma(ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements = {
@@ -9250,26 +9035,13 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
     };
 
     if (version == 6) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-            vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+            pc, elements);
     } else if (version == 7) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
-            vk_subbuffer{ d_D, dst_offset, dst_size }
-        }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
+            pc, elements);
     } else {
         // shouldn't happen
         GGML_ASSERT(false);
@@ -9354,40 +9126,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head, head_dim, n_group, n_tok
     };
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
-    }
-
-    vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
-    size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
-    bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
-
-    if (ctx->device->uma) {
-        for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
-            srcs_uma[i] = d_srcs[i] != nullptr;
-        }
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
-        dst_uma = d_D != nullptr;
-    }
-
-    if (!dst_uma) {
-        d_D = dst_buf_ctx->dev_buffer;
-        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
-    }
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        if (!srcs_uma[i]) {
-            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
-            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
-        }
-    }
-
-    size_t dst_size = ggml_nbytes(dst);
-    size_t src_sizes[GGML_MAX_SRC];
-    for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
-        src_sizes[i] = ggml_nbytes(dst->src[i]);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
+    vk_subbuffer src_buf[7] = {};
+    for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer_uma(ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements;
@@ -9397,16 +9139,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t num_workgroups_y = n_seq;
     elements = { num_workgroups_x, num_workgroups_y, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
-        vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
-        vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
-        vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
-        vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
-        vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
-        vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
+        pc, elements);
 }
 
 static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -9456,66 +9191,17 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
         return;
     }
 
-    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
-    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
-    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
-    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
-    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
-
-    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
-    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
-    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
-        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
-        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
-        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
-        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
-
-        X_uma = d_X != nullptr;
-        G_uma = d_G != nullptr;
-        GM_uma = d_GM != nullptr;
-        GV_uma = d_GV != nullptr;
-        P_uma = d_P != nullptr;
-    }
-
-    if (!X_uma) {
-        d_X = x_buf_ctx->dev_buffer;
-        x_offset = vk_tensor_offset(x) + x->view_offs;
-    }
-    if (!G_uma) {
-        d_G = g_buf_ctx->dev_buffer;
-        g_offset = vk_tensor_offset(g) + g->view_offs;
-    }
-    if (!GM_uma) {
-        d_GM = gm_buf_ctx->dev_buffer;
-        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
-    }
-    if (!GV_uma) {
-        d_GV = gv_buf_ctx->dev_buffer;
-        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
-    }
-    if (!P_uma) {
-        d_P = p_buf_ctx->dev_buffer;
-        p_offset = vk_tensor_offset(p) + p->view_offs;
-    }
-
-    const uint64_t x_size = ggml_nbytes(x);
-    const uint64_t g_size = ggml_nbytes(g);
-    const uint64_t gm_size = ggml_nbytes(gm);
-    const uint64_t gv_size = ggml_nbytes(gv);
-    const uint64_t p_size = ggml_nbytes(p);
+    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer_uma(ctx, x);
+    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer_uma(ctx, g);
+    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer_uma(ctx, gm);
+    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer_uma(ctx, gv);
+    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer_uma(ctx, p);
 
     std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_X, x_offset, x_size },
-        vk_subbuffer{ d_G, g_offset, g_size },
-        vk_subbuffer{ d_GM, gm_offset, gm_size },
-        vk_subbuffer{ d_GV, gv_offset, gv_size },
-        vk_subbuffer{ d_P, p_offset, p_size },
-    }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {x_buf, g_buf, gm_buf, gv_buf, p_buf},
+        pc, elements);
 }
 
 static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -9851,45 +9537,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
         return;
     }
 
-    ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
-    ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
-    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
-
-    vk_buffer d_logits = nullptr;
-    size_t logits_buf_offset = 0;
-    vk_buffer d_weights = nullptr;
-    size_t weights_buf_offset = 0;
-    vk_buffer d_ids = nullptr;
-    size_t ids_buf_offset = 0;
-
-    bool logits_uma = false;
-    bool weights_uma = false;
-    bool ids_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
-        ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
-        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
-        logits_uma = d_logits != nullptr;
-        weights_uma = d_weights != nullptr;
-        ids_uma = d_ids != nullptr;
-    }
-
-    if (!logits_uma) {
-        d_logits = logits_buf_ctx->dev_buffer;
-        logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
-        GGML_ASSERT(d_logits != nullptr);
-    }
-    if (!weights_uma) {
-        d_weights = weights_buf_ctx->dev_buffer;
-        weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
-        GGML_ASSERT(d_weights != nullptr);
-    }
-    if (!ids_uma) {
-        d_ids = ids_buf_ctx->dev_buffer;
-        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
-        GGML_ASSERT(d_ids != nullptr);
-    }
+    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer_uma(ctx, logits);
+    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer_uma(ctx, weights);
+    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer_uma(ctx, ids);
 
     vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
@@ -9905,12 +9555,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t rows_per_block = 4;
     std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-        {
-            ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
-            ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
-            ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
-        }, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop, bool dryrun = false) {
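
Note (editorial, not part of the patch): a minimal usage sketch of the two new helpers, assuming a hypothetical element-wise op whose tensors src0 and dst live either in a Vulkan device buffer or, on a UMA device, in pinned host memory. The names ctx, subctx, pipeline, pc and elements are the same locals the surrounding functions already use.

    // Resolve each tensor to a {buffer, offset, size} descriptor range.
    // On UMA devices the *_uma variant first asks ggml_vk_host_get() whether the
    // pointer maps to pinned host memory; otherwise it falls back to the tensor's
    // device buffer and vk_tensor_offset() + view_offs, like the removed boilerplate did.
    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer_uma(ctx, src0);
    vk_subbuffer dst_buf  = ggml_vk_tensor_subbuffer(ctx, dst);

    // The subbuffers are passed straight to the dispatch, replacing the old
    // per-operation d_X / x_buf_offset / x_sz bookkeeping.
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);

Passing allow_misalign = true is only valid when the shader can handle offsets that are not a multiple of minStorageBufferOffsetAlignment: the helper then rounds the descriptor offset down and enlarges the range by the misaligned bytes, while ggml_vk_op_f32 records the misalignment in the push constants via init_pushconst_tensor_offsets().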