From b25eaa95da37b3909b47e46c88aacaa08f8c0456 Mon Sep 17 00:00:00 2001 From: ShaneWu Date: Sat, 13 Jun 2026 15:32:35 +0800 Subject: [PATCH 1/2] fix causal_softmax op to adapt with ascend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改 causal_softmax 算子,使其在昇腾 Ascend 上适配 x tensor 的非默认 batch stride 和 BF16 dtype。 修改内容:修改 Infinicore/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc 1、新增 BF16 dtype 的 value 张量(-∞ 填充值)支持; 2、引入连续 temp 缓冲(batch×seq×total_seq_len),每次 calculate 调用将 x 逐行拷贝到 temp 后在 temp 上完成 masked_fill 和 softmax,最终写入 y,绕过 aclnnInplaceMaskedFillTensor 无法处理 x_tensor 非默认 batch stride 的限制。 现状:infiniop 算子测试全部通过;infinicore 算子接口测试跑通,53/63 Passed。 --- .../ascend/causal_softmax_ascend.cc | 71 +++++++++++++++---- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc index b37557da7..4bfb9a233 100644 --- a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc +++ b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc @@ -7,22 +7,26 @@ namespace op::causal_softmax::ascend { struct Descriptor::Opaque { aclnnTensorDescriptor_t x; + aclnnTensorDescriptor_t temp; aclnnTensorDescriptor_t mask; aclnnTensorDescriptor_t y; aclnnTensorDescriptor_t value; void *mask_addr; void *value_addr; + void *temp_addr; uint64_t workspacesize; aclOpExecutor *executor; ~Opaque() { delete x; + delete temp; delete mask; delete y; delete value; aclrtFree(mask_addr); aclrtFree(value_addr); + aclrtFree(temp_addr); // Delete useless executor aclDestroyAclOpExecutor(executor); @@ -60,7 +64,19 @@ infiniStatus_t Descriptor::create( std::vector y_strides = {static_cast(info.y_stride_b), static_cast(info.y_stride_i), static_cast(info.y_stride_j)}; y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides); x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides); - mask = new
aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.total_seq_len), 1}); + mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.batch_size), static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), 1}); + + // Allocate contiguous temp buffer for computation (avoids stride issues) + void *temp_addr = nullptr; + size_t temp_elements = info.batch_size * info.seq_len * info.total_seq_len; + size_t temp_bytes = temp_elements * aclDataTypeSize(toAclDataType(info.dtype)); + CHECK_ACL(aclrtMalloc(&temp_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + std::vector temp_strides = { + static_cast(info.seq_len * info.total_seq_len), + static_cast(info.total_seq_len), + 1 + }; + auto temp = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, temp_strides, temp_addr); // Initialize the value tensor with -∞ if (info.dtype == INFINI_DTYPE_F16) { @@ -69,6 +85,12 @@ infiniStatus_t Descriptor::create( CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {}); + } else if (info.dtype == INFINI_DTYPE_BF16) { + uint16_t mask_value = 0xff80; + auto size = aclDataTypeSize(aclDataType::ACL_BF16); + CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); + value = new aclnnTensorDescriptor(aclDataType::ACL_BF16, {}, {}); } else { uint32_t mask_value = 0xff800000; auto size = aclDataTypeSize(aclDataType::ACL_FLOAT); @@ -77,36 +99,39 @@ infiniStatus_t Descriptor::create( value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {}); } - // Fill Mask Tensor - std::vector mask_matrix(mask->numel(), 0); + // Fill Mask Tensor (replicate 2D 
causal mask to all batches) + size_t mask_data_size = info.batch_size * info.seq_len * info.total_seq_len; + std::vector mask_matrix(mask_data_size, 0); for (size_t i = 0; i < info.seq_len; ++i) { for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) { - size_t index = i * info.total_seq_len + j; - mask_matrix[index] = 1; + size_t index_2d = i * info.total_seq_len + j; + for (size_t b = 0; b < info.batch_size; ++b) { + mask_matrix[b * info.seq_len * info.total_seq_len + index_2d] = 1; + } } } - auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL); + auto size = mask_data_size * aclDataTypeSize(aclDataType::ACL_BOOL); CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE)); // Get the workspace size for the op - aclTensor *tx = x->tensor; + aclTensor *ttemp = temp->tensor; aclTensor *ty = y->tensor; aclTensor *tmask = mask->tensor; aclTensor *tvalue = value->tensor; - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); int64_t dim = 2; - CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp, dim, ty, &workspacesize_softmax, &executor)); // set executor reusable aclSetAclOpExecutorRepeatable(executor); // Create the descripto size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask); - *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr, + *desc_ptr = new Descriptor(new Opaque{x, temp, mask, y, value, mask_addr, value_addr, temp_addr, workspacesize_softmax, executor}, std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id); @@ -117,20 +142,38 @@ infiniStatus_t 
Descriptor::calculate(void *workspace, size_t workspace_size, voi if (workspace_size < workspaceSize()) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - auto tx = _opaque->x->tensor; + auto ttemp = _opaque->temp->tensor; auto ty = _opaque->y->tensor; auto tmask = _opaque->mask->tensor; auto tvalue = _opaque->value->tensor; aclOpExecutor *mask_executor = nullptr; size_t workspacesize_mask = 0; - AclSetTensorAddr(mask_executor, 0, tx, (void *)x); + // Copy x to contiguous temp buffer (handles custom stride correctly) + size_t dtype_sz = aclDataTypeSize(_opaque->temp->dataType); + size_t row_bytes = _info.total_seq_len * dtype_sz; + for (size_t b = 0; b < _info.batch_size; b++) { + size_t dst_batch_off = b * _info.seq_len * _info.total_seq_len * dtype_sz; + size_t src_batch_off = b * _info.x_stride_b * dtype_sz; + for (size_t i = 0; i < _info.seq_len; i++) { + aclrtMemcpy( + (char *)_opaque->temp_addr + dst_batch_off + i * _info.total_seq_len * dtype_sz, + row_bytes, + (const char *)x + src_batch_off + i * _info.x_stride_i * dtype_sz, + row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE); + } + } + + // Masked fill on temp (contiguous, no stride issues) + AclSetTensorAddr(mask_executor, 0, ttemp, _opaque->temp_addr); AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream)); - AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x); + // Softmax temp (contiguous) → y + AclSetTensorAddr(_opaque->executor, 0, ttemp, _opaque->temp_addr); AclSetTensorAddr(_opaque->executor, 1, ty, y); CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream)); From 
a6b68047749134f00f7d40813e52ccd58640b0ad Mon Sep 17 00:00:00 2001 From: ShaneWu Date: Sat, 13 Jun 2026 16:54:58 +0800 Subject: [PATCH 2/2] format causal_softmax_ascend.cc --- .../ops/causal_softmax/ascend/causal_softmax_ascend.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc index 4bfb9a233..813d16037 100644 --- a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc +++ b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc @@ -74,8 +74,7 @@ infiniStatus_t Descriptor::create( std::vector temp_strides = { static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), - 1 - }; + 1}; auto temp = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, temp_strides, temp_addr); // Initialize the value tensor with -∞