From 9de7f8f6ebe926d13937da301d74ec2ce962533f Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 28 May 2026 03:08:30 +0000 Subject: [PATCH 1/4] feat: add MLASelfAttention Module --- .../transformer/causal_self_attention.h | 7 - .../modules/transformer/mla_self_attention.h | 50 +++++ .../include/nn/modules/transformer/utils.h | 6 + .../transformer/causal_self_attention.cc | 38 +--- .../modules/transformer/mla_self_attention.cc | 185 ++++++++++++++++++ .../src/nn/modules/transformer/utils.cc | 38 ++++ .../test_transformer_architecture.cc | 25 +++ 7 files changed, 305 insertions(+), 44 deletions(-) create mode 100644 infini_train/include/nn/modules/transformer/mla_self_attention.h create mode 100644 infini_train/src/nn/modules/transformer/mla_self_attention.cc diff --git a/infini_train/include/nn/modules/transformer/causal_self_attention.h b/infini_train/include/nn/modules/transformer/causal_self_attention.h index 5ac55e31..7a96714f 100644 --- a/infini_train/include/nn/modules/transformer/causal_self_attention.h +++ b/infini_train/include/nn/modules/transformer/causal_self_attention.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include "infini_train/include/nn/modules/module.h" @@ -43,12 +42,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule> ForwardWithRoPE(const std::vector> &x); - // RoPE helper methods - std::tuple, std::shared_ptr> - ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis); - // GQA helper method std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep); }; diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h new file mode 100644 index 00000000..b4419e43 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { + +class MLASelfAttention : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "MLASelfAttention"; + + static constexpr char kQAProjLayerName[] = "q_a_proj"; + static constexpr char kQANormLayerName[] = "q_a_layernorm"; + static constexpr char kQBProjLayerName[] = "q_b_proj"; + static constexpr char kKVAProjLayerName[] = "kv_a_proj_with_mqa"; + static constexpr char kKVANormLayerName[] = "kv_a_layernorm"; + static constexpr char kKVBProjLayerName[] = "kv_b_proj"; + static constexpr char kCProjLayerName[] = "c_proj"; + + static constexpr char kParamBiasName[] = "bias"; + + explicit MLASelfAttention(const TransformerConfig &config); + MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim); + + std::vector> + Forward(const std::vector> &x) override; + +private: + TransformerConfig config_; + int64_t n_head_ = 0; + int64_t n_embd_ = 0; + int64_t local_n_head_ = 0; + + int64_t q_lora_rank_ = 0; + int64_t kv_lora_rank_ = 0; + int64_t qk_nope_head_dim_ = 0; + int64_t qk_rope_head_dim_ = 0; + int64_t qk_head_dim_ = 0; + int64_t v_head_dim_ = 0; + + void SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim); + +}; + +} // namespace infini_train::nn diff --git a/infini_train/include/nn/modules/transformer/utils.h b/infini_train/include/nn/modules/transformer/utils.h index d3a62c63..30db08e6 100644 --- a/infini_train/include/nn/modules/transformer/utils.h +++ b/infini_train/include/nn/modules/transformer/utils.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include "infini_train/include/tensor.h" @@ -8,4 +10,8 @@ namespace infini_train { // RoPE helper method std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, Device device = Device()); + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis); } // namespace infini_train diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 5ea9eec5..7320ca12 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -12,6 +12,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -130,43 +131,6 @@ CausalSelfAttention::ForwardStandard(const std::vector, std::shared_ptr> -CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis) { - // Reshape freqs_cis for broadcasting - const auto &x_shape = xq->Dims(); // (B, T, H, D) - const int64_t T = x_shape[1]; - const int64_t D = x_shape[3]; - - std::vector target_shape = {1, T, 1, D / 2, 2}; - auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) - - auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) - auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) - - auto slice_pair = [](const std::shared_ptr &x) { - auto even = x->Slice(-1, 0, x->Dims().back(), 2); - auto odd = x->Slice(-1, 1, x->Dims().back(), 2); - return std::make_pair(even, odd); - }; - - auto [q_even, q_odd] = slice_pair(xq); - auto q_rotated_left = q_even * cos - q_odd * sin; - auto q_rotated_right = q_even * sin + q_odd * cos; - auto q_rotated - = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); - - auto [k_even, k_odd] = slice_pair(xk); - auto k_rotated_left = k_even * cos - k_odd * sin; - auto k_rotated_right = k_even * sin + k_odd * cos; - auto k_rotated - = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); - - return {q_rotated, k_rotated}; -} - std::shared_ptr CausalSelfAttention::RepeatKV(const std::shared_ptr &x, int64_t n_rep) { const auto &shape = x->Dims(); diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc new file mode 100644 index 00000000..097cf830 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -0,0 +1,185 @@ +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn { +namespace { +int64_t DefaultQKVHeadDim(const TransformerConfig &config) { + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + return config.n_embd / config.n_head; +} + +int64_t DefaultQKRoPEHeadDim(const TransformerConfig &config) { + return DefaultQKVHeadDim(config); +} + +int64_t DefaultQKNoPEHeadDim(const TransformerConfig &config) { + return DefaultQKVHeadDim(config); +} +} // namespace + +MLASelfAttention::MLASelfAttention(const TransformerConfig &config) + : MLASelfAttention(config, + /*q_lora_rank=*/config.n_embd, + /*kv_lora_rank=*/config.n_embd, + /*qk_nope_head_dim=*/DefaultQKNoPEHeadDim(config), + /*qk_rope_head_dim=*/DefaultQKRoPEHeadDim(config), + /*v_head_dim=*/DefaultQKVHeadDim(config)) {} + +MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) + : CloneableModule(kType), config_(config) { + SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim); + + modules_[kQAProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear); + modules_[kQANormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); + modules_[kQBProjLayerName] = std::make_shared( + /*in_features=*/q_lora_rank_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + modules_[kKVAProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear); + modules_[kKVANormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); + modules_[kKVBProjLayerName] = std::make_shared( + /*in_features=*/kv_lora_rank_, + /*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_), + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + modules_[kCProjLayerName] = std::make_shared( + /*in_features=*/n_head_ * v_head_dim_, + /*out_features=*/n_embd_, + /*bias=*/config_.add_bias_linear, + /*reduce_output=*/true, + /*input_is_parallel=*/true, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); +} + +void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) { + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; + CHECK_GT(q_lora_rank, 0) << "q_lora_rank must be positive"; + CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; + CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; + CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; + CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive"; + CHECK_EQ(qk_rope_head_dim % 2, 0) << "qk_rope_head_dim must be even for RoPE"; + + n_head_ = config.n_head; + n_embd_ = config.n_embd; + local_n_head_ = n_head_ / tp_world_size; + + q_lora_rank_ = q_lora_rank; + kv_lora_rank_ = kv_lora_rank; + qk_nope_head_dim_ = qk_nope_head_dim; + qk_rope_head_dim_ = qk_rope_head_dim; + qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; + v_head_dim_ = v_head_dim; +} + +std::vector> +MLASelfAttention::Forward(const std::vector> &x) { + CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states"; + + const auto B = x[0]->Dims()[0]; + const auto C = x[0]->Dims()[2]; + CHECK_EQ(C, n_embd_) << "hidden size must match n_embd"; + + const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; + const auto external_mask = x.size() > 3 ? x[3] : nullptr; + if (config_.attention_type == AttentionType::kRoPE) { + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + } + + // (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope)) + auto q = (*modules_[kQAProjLayerName])({x[0]})[0]; + q = (*modules_[kQANormLayerName])({q})[0]; + q = (*modules_[kQBProjLayerName])({q})[0]; + const auto T = q->Dims()[1]; + q = q->View({B, T, local_n_head_, qk_head_dim_}); + + auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_); + auto q_pe = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + + // (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key. + auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName])({x[0]})[0]; + auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_); + auto k_pe = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_) + ->Contiguous(); + if (nn::parallel::global::GetSequenceParallelEnabled()) { + k_pe = nn::parallel::GatherFromSPRegionFunc(k_pe)[0]; + } + k_pe = k_pe->View({B, T, 1, qk_rope_head_dim_}); + + // (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v)) + auto kv = (*modules_[kKVANormLayerName])({compressed_kv})[0]; + kv = (*modules_[kKVBProjLayerName])({kv})[0]; + kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_}); + auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_); + auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); + + if (config_.attention_type == AttentionType::kRoPE) { + std::tie(q_pe, k_pe) = ApplyRotaryEmbedding(q_pe, k_pe, freqs_cis); + } + + k_pe = k_pe->RepeatInterleave(local_n_head_, 2); + q = nn::function::Concat(std::vector>{q_nope, q_pe}, -1); + auto k = nn::function::Concat(std::vector>{k_nope, k_pe}, -1); + + // (B, T, H_local, D) -> (B, H_local, T, D) + q = q->Transpose(1, 2); + k = k->Transpose(1, 2); + v = v->Transpose(1, 2); + + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(qk_head_dim_))); + if (external_mask) { + att = att->MaskedFill(external_mask, std::numeric_limits::lowest()); + } else { + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + } + att = nn::function::Softmax(att, -1); + + auto y = att->Matmul(v); + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); + y = (*modules_[kCProjLayerName])({y})[0]; + return {y}; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/transformer/utils.cc b/infini_train/src/nn/modules/transformer/utils.cc index 98505fd0..4ec11f2d 100644 --- a/infini_train/src/nn/modules/transformer/utils.cc +++ b/infini_train/src/nn/modules/transformer/utils.cc @@ -1,5 +1,9 @@ #include "infini_train/include/nn/modules/transformer/utils.h" +#include +#include +#include + #include "glog/logging.h" #include "infini_train/include/nn/functional.h" @@ -27,4 +31,38 @@ std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta return freqs_cis; } + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis) { + const auto &x_shape = xq->Dims(); // (B, T, H, D) + const int64_t T = x_shape[1]; + const int64_t D = x_shape[3]; + + std::vector target_shape = {1, T, 1, D / 2, 2}; + auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) + + auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) + auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) + + auto slice_pair = [](const std::shared_ptr &x) { + auto even = x->Slice(-1, 0, x->Dims().back(), 2); + auto odd = x->Slice(-1, 1, x->Dims().back(), 2); + return std::make_pair(even, odd); + }; + + auto [q_even, q_odd] = slice_pair(xq); + auto q_rotated_left = q_even * cos - q_odd * sin; + auto q_rotated_right = q_even * sin + q_odd * cos; + auto q_rotated + = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); + + auto [k_even, k_odd] = slice_pair(xk); + auto k_rotated_left = k_even * cos - k_odd * sin; + auto k_rotated_right = k_even * sin + k_odd * cos; + auto k_rotated + = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); + + return {q_rotated, k_rotated}; +} } // namespace infini_train diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..f36d10f6 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -7,6 +7,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" @@ -110,6 +111,30 @@ TEST_P(TransformerModuleTest, StandardAttention) { EXPECT_EQ(output[0]->Dims(), input->Dims()); } +TEST_P(TransformerModuleTest, MLAAttention) { + SKIP_CPU(); + nn::TransformerConfig config; + config.n_embd = 64; + config.n_head = 4; + config.block_size = 16; + config.attention_type = nn::AttentionType::kStandard; + config.add_bias_linear = true; + + auto attn = std::make_shared( + config, + /*q_lora_rank=*/32, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16); + attn->To(GetDevice()); + EXPECT_FALSE(attn->Parameters().empty()); + + auto input = std::make_shared(std::vector{2, 8, 64}, DataType::kFLOAT32, GetDevice()); + auto output = (*attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); +} + TEST_P(TransformerModuleTest, GPT2TransformerLayer) { SKIP_CPU(); nn::TransformerConfig config; From 87ca357154f544aa783bbe533e451027ceb446f6 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 28 May 2026 13:33:10 +0000 Subject: [PATCH 2/4] feat: support q_lora/non-q_lora and tp/non-tp variations --- .../modules/transformer/mla_self_attention.h | 26 ++- infini_train/include/nn/parallel/utils.h | 1 + .../modules/transformer/mla_self_attention.cc | 201 ++++++++++++++---- .../src/nn/parallel/tensor_parallel.cc | 37 ++++ .../test_transformer_architecture.cc | 34 +++ 5 files changed, 242 insertions(+), 57 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h index b4419e43..75b9da3a 100644 --- a/infini_train/include/nn/modules/transformer/mla_self_attention.h +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -12,19 +12,21 @@ class MLASelfAttention : public infini_train::nn::CloneableModule> Forward(const std::vector> &x) override; @@ -42,9 +44,13 @@ class MLASelfAttention : public infini_train::nn::CloneableModule GetPipelineParallelGroupRanks(int global_rank); // TP/SP Communication Helper Functions std::vector> GatherFromTPRegionFunc(const std::shared_ptr &input); +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> ReduceScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input); std::vector> ScatterToTPRegionFunc(const std::shared_ptr &input); diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc index 097cf830..423c91c5 100644 --- a/infini_train/src/nn/modules/transformer/mla_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -43,30 +43,65 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config) /*v_head_dim=*/DefaultQKVHeadDim(config)) {} MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, + bool q_down_proj_use_tp, bool kv_down_proj_use_tp) : CloneableModule(kType), config_(config) { - SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim); - - modules_[kQAProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/q_lora_rank_, - /*bias=*/config_.add_bias_linear); - modules_[kQANormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); - modules_[kQBProjLayerName] = std::make_shared( - /*in_features=*/q_lora_rank_, - /*out_features=*/n_head_ * qk_head_dim_, - /*bias=*/config_.add_bias_linear, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + q_down_proj_use_tp, kv_down_proj_use_tp); + + if (use_q_lora_) { + if (q_down_proj_use_tp_) { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear); + } + modules_[kQLayerNormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); + modules_[kLinearQUpProjLayerName] = std::make_shared( + /*in_features=*/q_lora_rank_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } - modules_[kKVAProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, - /*bias=*/config_.add_bias_linear); - modules_[kKVANormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); - modules_[kKVBProjLayerName] = std::make_shared( + if (kv_down_proj_use_tp_) { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear); + } + modules_[kKVLayerNormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); + modules_[kLinearKVUpProjLayerName] = std::make_shared( /*in_features=*/kv_lora_rank_, /*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_), /*bias=*/config_.add_bias_linear, @@ -75,7 +110,7 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo /*skip_bias_add=*/false, /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - modules_[kCProjLayerName] = std::make_shared( + modules_[kLinearProjLayerName] = std::make_shared( /*in_features=*/n_head_ * v_head_dim_, /*out_features=*/n_embd_, /*bias=*/config_.add_bias_linear, @@ -89,12 +124,13 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo } void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) { + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, + bool q_down_proj_use_tp, bool kv_down_proj_use_tp) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; - CHECK_GT(q_lora_rank, 0) << "q_lora_rank must be positive"; + CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA"; CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; @@ -105,80 +141,151 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q n_embd_ = config.n_embd; local_n_head_ = n_head_ / tp_world_size; - q_lora_rank_ = q_lora_rank; + use_q_lora_ = q_lora_rank != -1; + q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0; kv_lora_rank_ = kv_lora_rank; qk_nope_head_dim_ = qk_nope_head_dim; qk_rope_head_dim_ = qk_rope_head_dim; qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; v_head_dim_ = v_head_dim; + q_down_proj_use_tp_ = q_down_proj_use_tp; + kv_down_proj_use_tp_ = kv_down_proj_use_tp; } std::vector> MLASelfAttention::Forward(const std::vector> &x) { CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states"; + // x[0]: (B, T_local, C) const auto B = x[0]->Dims()[0]; const auto C = x[0]->Dims()[2]; CHECK_EQ(C, n_embd_) << "hidden size must match n_embd"; + // freqs_cis: (T, D_rope / 2, 2) const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; + // external_mask: (1, 1, T, T) const auto external_mask = x.size() > 3 ? x[3] : nullptr; if (config_.attention_type == AttentionType::kRoPE) { CHECK(freqs_cis != nullptr) << "freqs_cis is null."; } - // (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope)) - auto q = (*modules_[kQAProjLayerName])({x[0]})[0]; - q = (*modules_[kQANormLayerName])({q})[0]; - q = (*modules_[kQBProjLayerName])({q})[0]; + const bool sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + + // ----------- Q PATH ----------- + // Q path, align with Megatron: + // - q_lora_rank == -1 -> linear_q_proj directly; + // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj. + std::shared_ptr q; + if (use_q_lora_) { + // linear_q_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_q) + // TP path before gather: (B, T, C) -> (B, T, R_q / TP) + // - Note that ColumnParallelLinear would perform a GatherFromSPRegion in the beginning + auto q_compressed = (*modules_[kLinearQDownProjLayerName])({x[0]})[0]; + if (q_down_proj_use_tp_ && q_compressed->Dims().back() != q_lora_rank_) { + // Gather the sharded latent dimension: (B, T, R_q / TP) -> (B, T, R_q). + q_compressed = nn::parallel::GatherFromTPRegionFunc(q_compressed)[0]; + if (sequence_parallel_enabled) { + // Keep the q_up input sequence-sharded: (B, T_full, R_q) -> (B, T_local, R_q). + q_compressed = nn::parallel::ScatterToSPRegionFunc(q_compressed)[0]; + } + } + // q_layernorm preserves shape: (B, T_local, R_q) + q_compressed = (*modules_[kQLayerNormLayerName])({q_compressed})[0]; + // linear_q_up_proj: (B, T_local, R_q) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQUpProjLayerName])({q_compressed})[0]; + } else { + // linear_q_proj direct path: (B, T, C) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQProjLayerName])({x[0]})[0]; + } + + // T should be the full seqlen after the q projection path gathers sequence-parallel input. const auto T = q->Dims()[1]; + // q: (B, T, H_local * D_qk) -> (B, T, H_local, D_qk) + // qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_ q = q->View({B, T, local_n_head_, qk_head_dim_}); + // q_nope: (B, T, H_local, D_nope), q_pos_emb: (B, T, H_local, D_rope) auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_); - auto q_pe = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + auto q_pos_emb = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + + // ----------- KV PATH ----------- + // linear_kv_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_kv + D_rope) + // TP path before gather: (B, T, C) -> (B, T, (R_kv + D_rope) / TP) + auto compressed_kv_with_pe = (*modules_[kLinearKVDownProjLayerName])({x[0]})[0]; + const auto kv_down_proj_out_dim = kv_lora_rank_ + qk_rope_head_dim_; + const bool kv_down_proj_output_is_sharded = compressed_kv_with_pe->Dims().back() != kv_down_proj_out_dim; + if (kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded) { + // Gather latent+RoPE dim: (B, T, (R_kv + D_rope) / TP) -> (B, T, R_kv + D_rope) + compressed_kv_with_pe = nn::parallel::GatherFromTPRegionFunc(compressed_kv_with_pe)[0]; + } - // (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key. - auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName])({x[0]})[0]; + // compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope) auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_); - auto k_pe = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_) - ->Contiguous(); - if (nn::parallel::global::GetSequenceParallelEnabled()) { - k_pe = nn::parallel::GatherFromSPRegionFunc(k_pe)[0]; + auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous(); + const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded + && sequence_parallel_enabled; + if (k_pos_emb_has_full_sequence) { + // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj. + // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv) + compressed_kv = nn::parallel::ScatterToSPRegionFunc(compressed_kv)[0]; + } else if (sequence_parallel_enabled) { + // Replicated down-proj path produces local k_pos_emb; gather it for attention. + // k_pos_emb: (B, T_local, D_rope) -> (B, T, D_rope) + k_pos_emb = nn::parallel::GatherFromSPRegionFunc(k_pos_emb)[0]; } - k_pe = k_pe->View({B, T, 1, qk_rope_head_dim_}); + // k_pos_emb: (B, T, D_rope) -> (B, T, 1, D_rope), shared across local heads. + k_pos_emb = k_pos_emb->View({B, T, 1, qk_rope_head_dim_}); - // (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v)) - auto kv = (*modules_[kKVANormLayerName])({compressed_kv})[0]; - kv = (*modules_[kKVBProjLayerName])({kv})[0]; + // (B, T, R_kv) -> kv_layernorm -> linear_kv_up_proj -> (B, T, H_local * (D_nope + D_v)) + // kv_layernorm preserves compressed_kv shape: (B, T_local, R_kv) + auto kv = (*modules_[kKVLayerNormLayerName])({compressed_kv})[0]; + // linear_kv_up_proj: (B, T_local, R_kv) -> (B, T, H_local * (D_nope + D_v)) + kv = (*modules_[kLinearKVUpProjLayerName])({kv})[0]; + // kv: (B, T, H_local * (D_nope + D_v)) -> (B, T, H_local, D_nope + D_v) kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_}); + // k_nope: (B, T, H_local, D_nope), v: (B, T, H_local, D_v) auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_); auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); if (config_.attention_type == AttentionType::kRoPE) { - std::tie(q_pe, k_pe) = ApplyRotaryEmbedding(q_pe, k_pe, freqs_cis); + // q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope) + std::tie(q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding(q_pos_emb, k_pos_emb, freqs_cis); } - k_pe = k_pe->RepeatInterleave(local_n_head_, 2); - q = nn::function::Concat(std::vector>{q_nope, q_pe}, -1); - auto k = nn::function::Concat(std::vector>{k_nope, k_pe}, -1); + // k_pos_emb: (B, T, 1, D_rope) -> (B, T, H_local, D_rope) + k_pos_emb = k_pos_emb->RepeatInterleave(local_n_head_, 2); + // q: (B, T, H_local, D_qk), k: (B, T, H_local, D_qk) + q = nn::function::Concat(std::vector>{q_nope, q_pos_emb}, -1); + auto k = nn::function::Concat(std::vector>{k_nope, k_pos_emb}, -1); - // (B, T, H_local, D) -> (B, H_local, T, D) + // ----------- CORE ATTN ----------- + // q/k: (B, T, H_local, D_qk) -> (B, H_local, T, D_qk) + // v: (B, T, H_local, D_v) -> (B, H_local, T, D_v) q = q->Transpose(1, 2); k = k->Transpose(1, 2); v = v->Transpose(1, 2); + // att: (B, H_local, T, T) auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(qk_head_dim_))); if (external_mask) { att = att->MaskedFill(external_mask, std::numeric_limits::lowest()); } else { + // mask: (1, 1, T, T) auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); } + // att: (B, H_local, T, T) att = nn::function::Softmax(att, -1); + // y: (B, H_local, T, D_v) auto y = att->Matmul(v); + // y: (B, H_local, T, D_v) -> (B, T, H_local, D_v) -> (B, T, H_local * D_v) y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); - y = (*modules_[kCProjLayerName])({y})[0]; + // linear_proj: (B, T, H_local * D_v) -> (B, T, C) + y = (*modules_[kLinearProjLayerName])({y})[0]; + return {y}; } diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 44ab8189..b83c5e52 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -45,6 +45,24 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso return gathered_output; } +std::shared_ptr ScatterAlongFirstDim(const std::shared_ptr &tensor) { + int world_size = global::GetTensorParallelSize(); + CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; + if (world_size == 1) { + return tensor; + } + + auto device = tensor->GetDevice(); + auto tp_group = ProcessGroupFactory::Instance(device.type()) + ->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + auto rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); + + CHECK_EQ(tensor->Dims()[0] % world_size, 0) << "First dimension must be divisible by TP world size"; + auto first_dim_size = tensor->Dims()[0] / world_size; + auto shards = tensor->Split(first_dim_size, 0); + return shards[rank]->Contiguous(); +} + std::shared_ptr GatherAlongLastDim(const std::shared_ptr &tensor) { int world_size = global::GetTensorParallelSize(); CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; @@ -214,6 +232,21 @@ class ReduceScatterToSPRegion : public autograd::Function { }; }; +class ScatterToSPRegion : public autograd::Function { +public: + static constexpr char kType[] = "ScatterToSPRegionFunction"; + + explicit ScatterToSPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {ScatterAlongFirstDim(input_tensors[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {GatherAlongFirstDim(grad_outputs[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; +}; + class GatherFromSPRegion : public autograd::Function { public: static constexpr char kType[] = "GatherFromSPRegionFunction"; @@ -263,6 +296,10 @@ std::vector> ReduceScatterToSPRegionFunc(const std::shar return std::make_shared()->Apply({input}); } +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input}); +} + std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input) { return std::make_shared()->Apply({input}); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index f36d10f6..047566ea 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" +#include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" @@ -12,6 +13,7 @@ #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" #include "tests/common/test_utils.h" @@ -129,10 +131,42 @@ TEST_P(TransformerModuleTest, MLAAttention) { /*v_head_dim=*/16); attn->To(GetDevice()); EXPECT_FALSE(attn->Parameters().empty()); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), nn::Linear::kType); auto input = std::make_shared(std::vector{2, 8, 64}, DataType::kFLOAT32, GetDevice()); auto output = (*attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto tp_down_attn = std::make_shared( + config, + /*q_lora_rank=*/32, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16, + /*q_down_proj_use_tp=*/true, + /*kv_down_proj_use_tp=*/true); + tp_down_attn->To(GetDevice()); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*tp_down_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto direct_q_attn = std::make_shared( + config, + /*q_lora_rank=*/-1, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16); + direct_q_attn->To(GetDevice()); + EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*direct_q_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); } TEST_P(TransformerModuleTest, GPT2TransformerLayer) { From dd18b35eeb8f9a80b34b0d95b9f8f06c4447df51 Mon Sep 17 00:00:00 2001 From: bolunz Date: Fri, 29 May 2026 02:36:49 +0000 Subject: [PATCH 3/4] fix: move mla args into TransformerConfig --- .../modules/transformer/mla_self_attention.h | 7 +-- .../modules/transformer/transformer_config.h | 10 ++++ .../modules/transformer/mla_self_attention.cc | 60 ++++++------------- .../src/nn/modules/transformer/transformer.cc | 17 ++++-- .../test_transformer_architecture.cc | 40 +++++-------- 5 files changed, 58 insertions(+), 76 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h index 75b9da3a..63177cc6 100644 --- a/infini_train/include/nn/modules/transformer/mla_self_attention.h +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -24,9 +24,6 @@ class MLASelfAttention : public infini_train::nn::CloneableModule> Forward(const std::vector> &x) override; @@ -48,9 +45,7 @@ class MLASelfAttention : public infini_train::nn::CloneableModule q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path. + int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention. + int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head. + bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj. + bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj. + // Normalization float norm_eps = 1e-5f; // epsilon in RMSNorm diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc index 423c91c5..7549e812 100644 --- a/infini_train/src/nn/modules/transformer/mla_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -19,35 +19,9 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -namespace { -int64_t DefaultQKVHeadDim(const TransformerConfig &config) { - CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; - return config.n_embd / config.n_head; -} - -int64_t DefaultQKRoPEHeadDim(const TransformerConfig &config) { - return DefaultQKVHeadDim(config); -} -int64_t DefaultQKNoPEHeadDim(const TransformerConfig &config) { - return DefaultQKVHeadDim(config); -} -} // namespace - -MLASelfAttention::MLASelfAttention(const TransformerConfig &config) - : MLASelfAttention(config, - /*q_lora_rank=*/config.n_embd, - /*kv_lora_rank=*/config.n_embd, - /*qk_nope_head_dim=*/DefaultQKNoPEHeadDim(config), - /*qk_rope_head_dim=*/DefaultQKRoPEHeadDim(config), - /*v_head_dim=*/DefaultQKVHeadDim(config)) {} - -MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, - bool q_down_proj_use_tp, bool kv_down_proj_use_tp) - : CloneableModule(kType), config_(config) { - SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, - q_down_proj_use_tp, kv_down_proj_use_tp); +MLASelfAttention::MLASelfAttention(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + SetupAttention(config); if (use_q_lora_) { if (q_down_proj_use_tp_) { @@ -123,15 +97,19 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo ->View({1, 1, config_.block_size, config_.block_size}); } -void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, - bool q_down_proj_use_tp, bool kv_down_proj_use_tp) { +void MLASelfAttention::SetupAttention(const TransformerConfig &config) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; - CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA"; - CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; + CHECK(!config.q_lora_rank.has_value() || config.q_lora_rank.value() > 0) << "q_lora_rank must be positive when set"; + + const auto default_head_dim = config.n_embd / config.n_head; + const int64_t kv_lora_rank = config.kv_lora_rank > 0 ? config.kv_lora_rank : config.n_embd; + const int64_t qk_nope_head_dim = config.qk_nope_head_dim > 0 ? config.qk_nope_head_dim : default_head_dim; + const int64_t qk_rope_head_dim = config.qk_rope_head_dim > 0 ? config.qk_rope_head_dim : default_head_dim; + const int64_t v_head_dim = config.v_head_dim > 0 ? config.v_head_dim : default_head_dim; + CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive"; @@ -141,15 +119,15 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q n_embd_ = config.n_embd; local_n_head_ = n_head_ / tp_world_size; - use_q_lora_ = q_lora_rank != -1; - q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0; + use_q_lora_ = config.q_lora_rank.has_value(); + q_lora_rank_ = config.q_lora_rank.value_or(0); kv_lora_rank_ = kv_lora_rank; qk_nope_head_dim_ = qk_nope_head_dim; qk_rope_head_dim_ = qk_rope_head_dim; qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; v_head_dim_ = v_head_dim; - q_down_proj_use_tp_ = q_down_proj_use_tp; - kv_down_proj_use_tp_ = kv_down_proj_use_tp; + q_down_proj_use_tp_ = config.q_down_proj_use_tp; + kv_down_proj_use_tp_ = config.kv_down_proj_use_tp; } std::vector> @@ -173,7 +151,7 @@ MLASelfAttention::Forward(const std::vector linear_q_proj directly; + // - q_lora_rank == nullopt -> linear_q_proj directly; // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj. std::shared_ptr q; if (use_q_lora_) { @@ -224,8 +202,8 @@ MLASelfAttention::Forward(const std::vectorSlice(-1, 0, kv_lora_rank_); auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous(); - const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded - && sequence_parallel_enabled; + const bool k_pos_emb_has_full_sequence + = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded && sequence_parallel_enabled; if (k_pos_emb_has_full_sequence) { // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj. // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv) @@ -285,7 +263,7 @@ MLASelfAttention::Forward(const std::vectorTranspose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); // linear_proj: (B, T, H_local * D_v) -> (B, T, C) y = (*modules_[kLinearProjLayerName])({y})[0]; - + return {y}; } diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..048cf96c 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -14,6 +14,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" @@ -28,8 +29,8 @@ TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config) modules_[kWTELayerName] = std::make_shared( config_.vocab_size, config_.n_embd, parallel::global::GetSequenceParallelEnabled()); - // LLaMA3 use RoPE, so they don't need position embedding - if (config_.activation_type == MLPType::kGELU) { + // RoPE-based models do not use absolute position embedding. + if (config_.attention_type == AttentionType::kStandard) { modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); } } @@ -85,7 +86,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea LOG(FATAL) << "Unsupported norm type"; } - modules_[kAttnLayerName] = std::make_shared(config); + if (config.multi_latent_attention) { + modules_[kAttnLayerName] = std::make_shared(config); + } else { + modules_[kAttnLayerName] = std::make_shared(config); + } modules_[kMlpLayerName] = std::make_shared(config); } @@ -135,8 +140,10 @@ std::vector> TransformerChunk::Forward(const std::vector // Init freqs_cis on device only once if (buffers_[kFreqsCisName] == nullptr) { - int64_t head_dim = config_.n_embd / config_.n_head; - buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta, + int64_t rope_head_dim = config_.multi_latent_attention && config_.qk_rope_head_dim > 0 + ? config_.qk_rope_head_dim + : config_.n_embd / config_.n_head; + buffers_[kFreqsCisName] = PrecomputeFreqsCis(rope_head_dim, config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, device); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index 047566ea..9ff660a8 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include "gtest/gtest.h" @@ -121,14 +122,14 @@ TEST_P(TransformerModuleTest, MLAAttention) { config.block_size = 16; config.attention_type = nn::AttentionType::kStandard; config.add_bias_linear = true; - - auto attn = std::make_shared( - config, - /*q_lora_rank=*/32, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16); + config.multi_latent_attention = true; + config.q_lora_rank = 32; + config.kv_lora_rank = 32; + config.qk_nope_head_dim = 8; + config.qk_rope_head_dim = 8; + config.v_head_dim = 16; + + auto attn = std::make_shared(config); attn->To(GetDevice()); EXPECT_FALSE(attn->Parameters().empty()); EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType); @@ -138,15 +139,10 @@ TEST_P(TransformerModuleTest, MLAAttention) { auto output = (*attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); - auto tp_down_attn = std::make_shared( - config, - /*q_lora_rank=*/32, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16, - /*q_down_proj_use_tp=*/true, - /*kv_down_proj_use_tp=*/true); + auto tp_down_config = config; + tp_down_config.q_down_proj_use_tp = true; + tp_down_config.kv_down_proj_use_tp = true; + auto tp_down_attn = std::make_shared(tp_down_config); tp_down_attn->To(GetDevice()); EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::parallel::ColumnParallelLinear::kType); @@ -155,13 +151,9 @@ TEST_P(TransformerModuleTest, MLAAttention) { output = (*tp_down_attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); - auto direct_q_attn = std::make_shared( - config, - /*q_lora_rank=*/-1, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16); + auto direct_q_config = config; + direct_q_config.q_lora_rank = std::nullopt; + auto direct_q_attn = std::make_shared(direct_q_config); direct_q_attn->To(GetDevice()); EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(), nn::parallel::ColumnParallelLinear::kType); From 937a71cb0ca15d66d491b3dbc6000a930d7316b0 Mon Sep 17 00:00:00 2001 From: bolunz Date: Tue, 2 Jun 2026 07:26:39 +0000 Subject: [PATCH 4/4] refactor: merge 2 MHA paths, rename attention_type to position_embedding_type --- example/gpt2/config.h | 5 +- example/llama3/config.h | 4 +- .../transformer/causal_self_attention.h | 8 -- .../modules/transformer/transformer_config.h | 16 ++- .../transformer/causal_self_attention.cc | 109 +++++------------- .../modules/transformer/mla_self_attention.cc | 4 +- .../src/nn/modules/transformer/transformer.cc | 22 ++-- .../test_transformer_architecture.cc | 8 +- 8 files changed, 62 insertions(+), 114 deletions(-) diff --git a/example/gpt2/config.h b/example/gpt2/config.h index 078f9fd5..71cc0a56 100644 --- a/example/gpt2/config.h +++ b/example/gpt2/config.h @@ -14,7 +14,7 @@ inline nn::TransformerConfig GPT2Config() { .n_head = 12, .n_kv_head = 12, .n_embd = 768, - .attention_type = nn::AttentionType::kStandard, + .position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute, .activation_type = nn::MLPType::kGELU, .norm_type = nn::NormType::kLayerNorm, .add_bias_linear = true, @@ -34,7 +34,8 @@ inline void SanitizeGPT2Config(const nn::TransformerConfig &c) { CHECK_GT(c.n_embd, 0); CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(c.n_kv_head, c.n_head) << "GPT-2 does not use GQA; n_kv_head must equal n_head"; - CHECK(c.attention_type == nn::AttentionType::kStandard) << "GPT-2 requires standard attention"; + CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kLearnedAbsolute) + << "GPT-2 requires learned absolute position embedding"; CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation"; CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm"; } diff --git a/example/llama3/config.h b/example/llama3/config.h index 6bc9124d..67ebd31f 100644 --- a/example/llama3/config.h +++ b/example/llama3/config.h @@ -14,7 +14,7 @@ inline nn::TransformerConfig LLaMA3Config() { .n_head = 32, .n_kv_head = 8, .n_embd = 2048, - .attention_type = nn::AttentionType::kRoPE, + .position_embedding_type = nn::PositionEmbeddingType::kRoPE, .activation_type = nn::MLPType::kSwiGLU, .norm_type = nn::NormType::kRMSNorm, .add_bias_linear = false, @@ -36,7 +36,7 @@ inline void SanitizeLLaMA3Config(const nn::TransformerConfig &c) { CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; CHECK_GT(c.n_embd, 0); CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; - CHECK(c.attention_type == nn::AttentionType::kRoPE) << "LLaMA-3 requires RoPE attention"; + CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kRoPE) << "LLaMA-3 requires RoPE position embedding"; CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "LLaMA-3 requires SwiGLU activation"; CHECK(c.norm_type == nn::NormType::kRMSNorm) << "LLaMA-3 requires RMSNorm"; CHECK(!c.add_bias_linear) << "LLaMA-3 has no bias in linear layers"; diff --git a/infini_train/include/nn/modules/transformer/causal_self_attention.h b/infini_train/include/nn/modules/transformer/causal_self_attention.h index 7a96714f..60373414 100644 --- a/infini_train/include/nn/modules/transformer/causal_self_attention.h +++ b/infini_train/include/nn/modules/transformer/causal_self_attention.h @@ -34,14 +34,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule> - ForwardStandard(const std::vector> &x); - - // RoPE-aware attention forward (LLaMA3 style: with RoPE, optional GQA) - std::vector> - ForwardWithRoPE(const std::vector> &x); - // GQA helper method std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep); }; diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index e374c180..5a440e60 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -10,9 +10,13 @@ enum class ModelType { kLLaMA3, // LLaMA3 }; -enum class AttentionType { - kStandard, // Standard attention - kRoPE // Rotary Position Embedding +enum class PositionEmbeddingType { + kLearnedAbsolute, // Megatron: learned_absolute + kRoPE, // Megatron: rope + kYarn, // Megatron: yarn + kMRoPE, // Megatron: mrope + kRelative, // Megatron: relative + kNone // Megatron: none }; enum class MLPType { @@ -34,9 +38,9 @@ struct TransformerConfig { int64_t n_kv_head = 12; // Num of Key/Value heads (<= n_head, < n_head if using GQA) int64_t n_embd = 768; // Hidden size - AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type - MLPType activation_type = MLPType::kGELU; // MLP activation type - NormType norm_type = NormType::kLayerNorm; // Normalization type + PositionEmbeddingType position_embedding_type = PositionEmbeddingType::kLearnedAbsolute; // Position embedding type. + MLPType activation_type = MLPType::kGELU; // MLP activation type + NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, // including: attention QKV projection, attention output projection, MLP FC layers (and diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 7320ca12..43e4ca51 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -1,6 +1,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include +#include #include #include #include @@ -43,12 +44,9 @@ CausalSelfAttention::CausalSelfAttention(const TransformerConfig &config) : Clon /*skip_bias_add=*/false, /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - // For standard attention (GPT2 style), precompute causal mask - if (config_.attention_type == AttentionType::kStandard) { - // causal mask: (1, 1, block_size, block_size) - buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) - ->View({1, 1, config_.block_size, config_.block_size}); - } + // causal mask: (1, 1, block_size, block_size) + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); } void CausalSelfAttention::SetupAttention(const TransformerConfig &config) { @@ -77,88 +75,21 @@ void CausalSelfAttention::SetupAttention(const TransformerConfig &config) { std::vector> CausalSelfAttention::Forward(const std::vector> &x) { - if (config_.attention_type == AttentionType::kRoPE) { - return ForwardWithRoPE(x); - } else { - return ForwardStandard(x); - } -} - -std::vector> -CausalSelfAttention::ForwardStandard(const std::vector> &x) { - auto tp_world_size = parallel::global::GetTensorParallelSize(); - - const auto B = x[0]->Dims()[0]; // bs - const auto C = x[0]->Dims()[2]; // n_embd - const int64_t head_dim = n_embd_ / n_head_; // per-head dim (global) - const int64_t local_C = n_embd_ / tp_world_size; // per-rank hidden - - // (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C) - // -> Split -> (3, B, T, local_C) - auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2); - - // (B, T, local_C) - auto q = qkv[0]; - auto k = qkv[1]; - auto v = qkv[2]; - - // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear - const auto T = q->Dims()[1]; - - // View to multi-head: local_n_head * head_dim == local_C - // (B, T, local_C) -> (B, T, h_l, Dh) -> (B, h_l, T, Dh) - k = k->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); - // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); - - // Get full tensor - // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) - y = (*modules_[kCProjLayerName])({y})[0]; - // (B, T, C) == (bs, seq_len, n_embd) - return {y}; -} - -std::shared_ptr CausalSelfAttention::RepeatKV(const std::shared_ptr &x, - int64_t n_rep) { - const auto &shape = x->Dims(); - const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; - - if (n_rep == 1) { - return x; - } - - return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); -} - -std::vector> -CausalSelfAttention::ForwardWithRoPE(const std::vector> &x) { const auto B = x[0]->Dims()[0]; // bs const auto C = x[0]->Dims()[2]; // n_embd const auto tp_size = nn::parallel::global::GetTensorParallelSize(); const auto C_local = C / tp_size; - const auto H_local = n_head_ / tp_size; + const auto H_local = local_n_head_; const auto KV_local = n_kv_head_ / tp_size; const auto D = head_dim_; // n_embd / n_head const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; - const auto start_pos = x.size() > 2 ? x[2] : nullptr; const auto mask = x.size() > 3 ? x[3] : nullptr; - CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + } // (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D) auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0]; @@ -176,10 +107,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vectorSlice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D}); - // -> RoPE on q, k - // q: (B, T, H_local, D) - // k: (B, T, KV_local, D) - std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + // q: (B, T, H_local, D), k: (B, T, KV_local, D) + std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); + } // TODO(zbl): use kv cache during inference // if (use_kv_) { ... } @@ -207,6 +138,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vectorMaskedFill(mask, std::numeric_limits::lowest()); + } else { + // fallback causal mask: (1, 1, T, T) + auto causal_mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + att = att->MaskedFill(causal_mask == 0, -std::numeric_limits::infinity()); } // (B, H_local, T, T) att = nn::function::Softmax(att, -1); @@ -221,4 +156,16 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector CausalSelfAttention::RepeatKV(const std::shared_ptr &x, + int64_t n_rep) { + const auto &shape = x->Dims(); + const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; + + if (n_rep == 1) { + return x; + } + + return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); +} + } // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc index 7549e812..536d077d 100644 --- a/infini_train/src/nn/modules/transformer/mla_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -143,7 +143,7 @@ MLASelfAttention::Forward(const std::vector 1 ? x[1] : nullptr; // external_mask: (1, 1, T, T) const auto external_mask = x.size() > 3 ? x[3] : nullptr; - if (config_.attention_type == AttentionType::kRoPE) { + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { CHECK(freqs_cis != nullptr) << "freqs_cis is null."; } @@ -227,7 +227,7 @@ MLASelfAttention::Forward(const std::vectorSlice(-1, 0, qk_nope_head_dim_); auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); - if (config_.attention_type == AttentionType::kRoPE) { + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { // q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope) std::tie(q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding(q_pos_emb, k_pos_emb, freqs_cis); } diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index 048cf96c..9d3cf35e 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -29,9 +29,11 @@ TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config) modules_[kWTELayerName] = std::make_shared( config_.vocab_size, config_.n_embd, parallel::global::GetSequenceParallelEnabled()); - // RoPE-based models do not use absolute position embedding. - if (config_.attention_type == AttentionType::kStandard) { + // Only learned absolute position embedding uses a trainable WPE table. + if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); + } else if (config_.position_embedding_type != PositionEmbeddingType::kRoPE) { + LOG(FATAL) << "Unsupported position embedding type"; } } @@ -45,7 +47,7 @@ std::vector> TransformerFirstStage::Forward(const std::v // (B, T) -> Embedding(V_local, C) -> (B, T, C) auto tok_emb = (*modules_[kWTELayerName])({x1}); - // Add position embedding only for models that use absolute position encoding + // Add position embedding only for models that use learned absolute position encoding. if (modules_.contains(kWPELayerName)) { // (T_local) // NOTE(zbl): Slice pos sequence when SP is enabled @@ -66,7 +68,7 @@ std::vector> TransformerFirstStage::Forward(const std::v // (B, T, C) return {tok_emb[0] + pos_emb[0]}; } else { - // For RoPE-based models (LLaMA3), no position embedding needed + // For RoPE-based models (LLaMA3), no absolute position embedding is needed. // (B, T, C) return tok_emb; } @@ -133,8 +135,8 @@ TransformerChunk::TransformerChunk(const TransformerConfig &config, int start_la std::vector> TransformerChunk::Forward(const std::vector> &x) { auto x1 = x[0]; - // Check if we need to pass RoPE parameters (for LLaMA3 style models) - if (config_.attention_type == AttentionType::kRoPE) { + // Check if we need to pass RoPE parameters (for LLaMA3 style models). + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { // For RoPE models, we need to prepare freqs_cis and potentially other parameters const auto device = x1->GetDevice(); @@ -163,9 +165,11 @@ std::vector> TransformerChunk::Forward(const std::vector for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; } - } else { - // Standard attention (GPT2 style) + } else if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { + // Learned absolute position embedding models (GPT-2 style). for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } + } else { + LOG(FATAL) << "Unsupported position embedding type"; } return {x1}; @@ -219,7 +223,7 @@ TransformerModel::TransformerModel(const TransformerConfig config) modules_[kPPFirstStageName] = std::make_shared(config_); transformer[TransformerFirstStage::kWTELayerName] = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWTELayerName); - if (config_.attention_type == AttentionType::kStandard) { + if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { transformer[TransformerFirstStage::kWPELayerName] = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWPELayerName); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index 9ff660a8..73c684f5 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -102,7 +102,7 @@ TEST_P(TransformerModuleTest, StandardAttention) { config.n_embd = 64; config.n_head = 4; config.n_kv_head = 4; - config.attention_type = nn::AttentionType::kStandard; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; config.add_bias_linear = true; auto attn = std::make_shared(config); @@ -120,7 +120,7 @@ TEST_P(TransformerModuleTest, MLAAttention) { config.n_embd = 64; config.n_head = 4; config.block_size = 16; - config.attention_type = nn::AttentionType::kStandard; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; config.add_bias_linear = true; config.multi_latent_attention = true; config.q_lora_rank = 32; @@ -198,7 +198,7 @@ TEST_P(TransformerModuleTest, LLaMA3Model) { config.n_head = 4; config.n_kv_head = 2; config.n_embd = 64; - config.attention_type = nn::AttentionType::kRoPE; + config.position_embedding_type = nn::PositionEmbeddingType::kRoPE; config.activation_type = nn::MLPType::kSwiGLU; config.norm_type = nn::NormType::kRMSNorm; config.add_bias_linear = false; @@ -225,7 +225,7 @@ TEST_P(TransformerModuleTest, StateDict) { config.n_kv_head = 2; config.n_embd = 32; config.vocab_size = 1000; - config.attention_type = nn::AttentionType::kStandard; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; config.activation_type = nn::MLPType::kGELU; config.norm_type = nn::NormType::kLayerNorm; config.add_bias_linear = true;