Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions src/infiniop/ops/causal_softmax/ascend/causal_softmax_ascend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,26 @@ namespace op::causal_softmax::ascend {

struct Descriptor::Opaque {
aclnnTensorDescriptor_t x;
aclnnTensorDescriptor_t temp;
aclnnTensorDescriptor_t mask;
aclnnTensorDescriptor_t y;
aclnnTensorDescriptor_t value;
void *mask_addr;
void *value_addr;
void *temp_addr;
uint64_t workspacesize;
aclOpExecutor *executor;

~Opaque() {
delete x;
delete temp;
delete mask;
delete y;
delete value;

aclrtFree(mask_addr);
aclrtFree(value_addr);
aclrtFree(temp_addr);

// Delete useless executor
aclDestroyAclOpExecutor(executor);
Expand Down Expand Up @@ -60,7 +64,18 @@ infiniStatus_t Descriptor::create(
std::vector<int64_t> y_strides = {static_cast<int64_t>(info.y_stride_b), static_cast<int64_t>(info.y_stride_i), static_cast<int64_t>(info.y_stride_j)};
y = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, y_strides);
x = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, x_strides);
mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast<int64_t>(info.seq_len), static_cast<int64_t>(info.total_seq_len)}, {static_cast<int64_t>(info.total_seq_len), 1});
mask = new aclnnTensorDescriptor(aclDataType::ACL_BOOL, {static_cast<int64_t>(info.batch_size), static_cast<int64_t>(info.seq_len), static_cast<int64_t>(info.total_seq_len)}, {static_cast<int64_t>(info.seq_len * info.total_seq_len), static_cast<int64_t>(info.total_seq_len), 1});

// Allocate contiguous temp buffer for computation (avoids stride issues)
void *temp_addr = nullptr;
size_t temp_elements = info.batch_size * info.seq_len * info.total_seq_len;
size_t temp_bytes = temp_elements * aclDataTypeSize(toAclDataType(info.dtype));
CHECK_ACL(aclrtMalloc(&temp_addr, temp_bytes, ACL_MEM_MALLOC_HUGE_FIRST));
std::vector<int64_t> temp_strides = {
static_cast<int64_t>(info.seq_len * info.total_seq_len),
static_cast<int64_t>(info.total_seq_len),
1};
auto temp = new aclnnTensorDescriptor(toAclDataType(info.dtype), shape, temp_strides, temp_addr);

// Initialize the value tensor with -∞
if (info.dtype == INFINI_DTYPE_F16) {
Expand All @@ -69,6 +84,12 @@ infiniStatus_t Descriptor::create(
CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT16, {}, {});
} else if (info.dtype == INFINI_DTYPE_BF16) {
uint16_t mask_value = 0xff80;
auto size = aclDataTypeSize(aclDataType::ACL_BF16);
CHECK_ACL(aclrtMalloc(&value_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(value_addr, size, &mask_value, size, ACL_MEMCPY_HOST_TO_DEVICE));
value = new aclnnTensorDescriptor(aclDataType::ACL_BF16, {}, {});
} else {
uint32_t mask_value = 0xff800000;
auto size = aclDataTypeSize(aclDataType::ACL_FLOAT);
Expand All @@ -77,36 +98,39 @@ infiniStatus_t Descriptor::create(
value = new aclnnTensorDescriptor(aclDataType::ACL_FLOAT, {}, {});
}

// Fill Mask Tensor
std::vector<char> mask_matrix(mask->numel(), 0);
// Fill Mask Tensor (replicate 2D causal mask to all batches)
size_t mask_data_size = info.batch_size * info.seq_len * info.total_seq_len;
std::vector<char> mask_matrix(mask_data_size, 0);
for (size_t i = 0; i < info.seq_len; ++i) {
for (size_t j = info.total_seq_len - info.seq_len + i + 1; j < info.total_seq_len; ++j) {
size_t index = i * info.total_seq_len + j;
mask_matrix[index] = 1;
size_t index_2d = i * info.total_seq_len + j;
for (size_t b = 0; b < info.batch_size; ++b) {
mask_matrix[b * info.seq_len * info.total_seq_len + index_2d] = 1;
}
}
}

auto size = mask->numel() * aclDataTypeSize(aclDataType::ACL_BOOL);
auto size = mask_data_size * aclDataTypeSize(aclDataType::ACL_BOOL);
CHECK_ACL(aclrtMalloc(&mask_addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(mask_addr, size, mask_matrix.data(), size, ACL_MEMCPY_HOST_TO_DEVICE));

// Get the workspace size for the op
aclTensor *tx = x->tensor;
aclTensor *ttemp = temp->tensor;
aclTensor *ty = y->tensor;
aclTensor *tmask = mask->tensor;
aclTensor *tvalue = value->tensor;

CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor));

int64_t dim = 2;
CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(tx, dim, ty, &workspacesize_softmax, &executor));
CHECK_ACL(aclnnSoftmaxGetWorkspaceSize(ttemp, dim, ty, &workspacesize_softmax, &executor));
// set executor reusable
aclSetAclOpExecutorRepeatable(executor);

// Create the descripto
size_t all_workspacesize = std::max(workspacesize_softmax, workspacesize_mask);

*desc_ptr = new Descriptor(new Opaque{x, mask, y, value, mask_addr, value_addr,
*desc_ptr = new Descriptor(new Opaque{x, temp, mask, y, value, mask_addr, value_addr, temp_addr,
workspacesize_softmax, executor},
std::move(info), all_workspacesize, handle_ascend->device, handle_ascend->device_id);

Expand All @@ -117,20 +141,38 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, voi
if (workspace_size < workspaceSize()) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto tx = _opaque->x->tensor;
auto ttemp = _opaque->temp->tensor;
auto ty = _opaque->y->tensor;
auto tmask = _opaque->mask->tensor;
auto tvalue = _opaque->value->tensor;
aclOpExecutor *mask_executor = nullptr;
size_t workspacesize_mask = 0;

AclSetTensorAddr(mask_executor, 0, tx, (void *)x);
// Copy x to contiguous temp buffer (handles custom stride correctly)
size_t dtype_sz = aclDataTypeSize(_opaque->temp->dataType);
size_t row_bytes = _info.total_seq_len * dtype_sz;
for (size_t b = 0; b < _info.batch_size; b++) {
size_t dst_batch_off = b * _info.seq_len * _info.total_seq_len * dtype_sz;
size_t src_batch_off = b * _info.x_stride_b * dtype_sz;
for (size_t i = 0; i < _info.seq_len; i++) {
aclrtMemcpy(
(char *)_opaque->temp_addr + dst_batch_off + i * _info.total_seq_len * dtype_sz,
row_bytes,
(const char *)x + src_batch_off + i * _info.x_stride_i * dtype_sz,
row_bytes,
ACL_MEMCPY_DEVICE_TO_DEVICE);
}
}

// Masked fill on temp (contiguous, no stride issues)
AclSetTensorAddr(mask_executor, 0, ttemp, _opaque->temp_addr);
AclSetTensorAddr(mask_executor, 1, tmask, _opaque->mask_addr);
AclSetTensorAddr(mask_executor, 2, tvalue, _opaque->value_addr);
CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(tx, tmask, tvalue, &workspacesize_mask, &mask_executor));
CHECK_ACL(aclnnInplaceMaskedFillTensorGetWorkspaceSize(ttemp, tmask, tvalue, &workspacesize_mask, &mask_executor));
CHECK_ACL(aclnnInplaceMaskedFillTensor(workspace, workspacesize_mask, mask_executor, stream));

AclSetTensorAddr(_opaque->executor, 0, tx, (void *)x);
// Softmax temp (contiguous) → y
AclSetTensorAddr(_opaque->executor, 0, ttemp, _opaque->temp_addr);
AclSetTensorAddr(_opaque->executor, 1, ty, y);
CHECK_ACL(aclnnSoftmax(workspace, _opaque->workspacesize, _opaque->executor, stream));

Expand Down
Loading