diff --git a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc index b37557da7..813d16037 100644 --- a/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc +++ b/src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc @@ -7,22 +7,26 @@ namespace op::causal_softmax::ascend { struct Descriptor::Opaque { aclnnTensorDescriptor_t x; + aclnnTensorDescriptor_t temp; aclnnTensorDescriptor_t mask; aclnnTensorDescriptor_t y; aclnnTensorDescriptor_t value; void *mask_addr; void *value_addr; + void *temp_addr; uint64_t workspacesize; aclOpExecutor *executor; ~Opaque() { delete x; + delete temp; delete mask; delete y; delete value; aclrtFree(mask_addr); aclrtFree(value_addr); + aclrtFree(temp_addr); // Delete useless executor aclDestroyAclOpExecutor(executor); @@ -60,7 +64,18 @@ infiniStatus_t Descriptor::create( std::vector y_strides = {static_cast(info.y_stride_b), static_cast(info.y_stride_i), static_cast(info.y_stride_j)}; y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides); x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides); - mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.total_seq_len), 1}); + mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast(info.batch_size), static_cast(info.seq_len), static_cast(info.total_seq_len)}, {static_cast(info.seq_len * info.total_seq_len), static_cast(info.total_seq_len), 1}); + + // Allocate contiguous temp buffer for computation (avoids stride issues) + void *temp_addr = nullptr; + size_t temp_elements = info.batch_size * info.seq_len * info.total_seq_len; + size_t temp_bytes = temp_elements * aclDataTypeSize(toAclDataType(info.dtype)); + CHECK_ACL(aclrtMalloc(&temp_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + std::vector temp_strides = { + static_cast(info.seq_len * info.total_seq_len), + static_cast(info.total_seq_len), + 1}; + auto temp = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, temp_strides, temp_addr); // Initialize the value tensor with -∞ if (info.dtype == INFINI_DTYPE_F16) { @@ -69,6 +84,12 @@ infiniStatus_t Descriptor::create( CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {}); + } else if (info.dtype == INFINI_DTYPE_BF16) { + uint16_t mask_value = 0xff80; + auto size = aclDataTypeSize(aclDataType::ACL_BF16); + CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE)); + value = new aclnnTensorDescriptor(aclDataType::ACL_BF16, {}, {}); } else { uint32_t mask_value = 0xff800000; auto size = aclDataTypeSize(aclDataType::ACL_FLOAT); @@ -77,36 +98,39 @@ infiniStatus_t Descriptor::create( value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {}); } - // Fill Mask Tensor - std::vector mask_matrix(mask->numel(), 0); + // Fill Mask Tensor (replicate 2D causal mask to all batches) + size_t mask_data_size = info.batch_size * info.seq_len * info.total_seq_len; + std::vector mask_matrix(mask_data_size, 0); for (size_t i = 0; i < info.seq_len; ++i) { for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) { - size_t index = i * info.total_seq_len + j; - mask_matrix[index] = 1; + size_t index_2d = i * info.total_seq_len + j; + for (size_t b = 0; b < info.batch_size; ++b) { + mask_matrix[b * info.seq_len * info.total_seq_len + index_2d] = 1; + } } } - auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL); + auto size = mask_data_size * aclDataTypeSize(aclDataType::ACL_BOOL); CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST)); CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE)); // Get the workspace size for the op - aclTensor *tx = x->tensor; + aclTensor *ttemp = temp->tensor; aclTensor *ty = y->tensor; aclTensor *tmask = mask->tensor; aclTensor *tvalue = value->tensor; - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); int64_t dim = 2; - CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor)); + CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp, dim, ty, &workspacesize_softmax, &executor)); // set executor reusable aclSetAclOpExecutorRepeatable(executor); // Create the descripto size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask); - *desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr, + *desc_ptr = new Descriptor(new Opaque{x, temp, mask, y, value, mask_addr, value_addr, temp_addr, workspacesize_softmax, executor}, std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id); @@ -117,20 +141,38 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi if (workspace_size < workspaceSize()) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - auto tx = _opaque->x->tensor; + auto ttemp = _opaque->temp->tensor; auto ty = _opaque->y->tensor; auto tmask = _opaque->mask->tensor; auto tvalue = _opaque->value->tensor; aclOpExecutor *mask_executor = nullptr; size_t workspacesize_mask = 0; - AclSetTensorAddr(mask_executor, 0, tx, (void *)x); + // Copy x to contiguous temp buffer (handles custom stride correctly) + size_t dtype_sz = aclDataTypeSize(_opaque->temp->dataType); + size_t row_bytes = _info.total_seq_len * dtype_sz; + for (size_t b = 0; b < _info.batch_size; b++) { + size_t dst_batch_off = b * _info.seq_len * _info.total_seq_len * dtype_sz; + size_t src_batch_off = b * _info.x_stride_b * dtype_sz; + for (size_t i = 0; i < _info.seq_len; i++) { + aclrtMemcpy( + (char *)_opaque->temp_addr + dst_batch_off + i * _info.total_seq_len * dtype_sz, + row_bytes, + (const char *)x + src_batch_off + i * _info.x_stride_i * dtype_sz, + row_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE); + } + } + + // Masked fill on temp (contiguous, no stride issues) + AclSetTensorAddr(mask_executor, 0, ttemp, _opaque->temp_addr); AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr); AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr); - CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor)); + CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor)); CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream)); - AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x); + // Softmax temp (contiguous) → y + AclSetTensorAddr(_opaque->executor, 0, ttemp, _opaque->temp_addr); AclSetTensorAddr(_opaque->executor, 1, ty, y); CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream));