diff --git a/cpp/external/katagocoreml/src/Converter.cpp b/cpp/external/katagocoreml/src/Converter.cpp index cb6ca80d9..72b78e736 100644 --- a/cpp/external/katagocoreml/src/Converter.cpp +++ b/cpp/external/katagocoreml/src/Converter.cpp @@ -29,9 +29,12 @@ void KataGoConverter::convert(const std::string& input_path, throw std::invalid_argument("max_batch_size must be >= min_batch_size or <= 0 for unlimited"); } - // Parse KataGo model - KataGoParser parser(input_path); - KataGoModelDesc model = parser.parse(); + // Parse KataGo model (parser + its decompressed buffer freed at end of scope) + KataGoModelDesc model; + { + KataGoParser parser(input_path); + model = parser.parse(); + } // Determine if using FP16 precision bool use_fp16 = (options.compute_precision == "FLOAT16"); @@ -52,9 +55,8 @@ void KataGoConverter::convert(const std::string& input_path, options.use_fp16_io); auto program = builder.build(); - // Get weights from builder - auto weights = builder.getWeights(); - std::vector weights_copy(weights.begin(), weights.end()); + // Serialize directly from the builder's weight views (no copy). + std::vector& weights = builder.getWeightsMutable(); // Update options with model metadata for serialization ConversionOptions final_options = options; @@ -82,7 +84,7 @@ void KataGoConverter::convert(const std::string& input_path, // Serialize to .mlpackage CoreMLSerializer serializer(final_options.specification_version); - serializer.serialize(program.get(), weights_copy, output_path, final_options); + serializer.serialize(program.get(), weights, output_path, final_options); } ModelInfo KataGoConverter::getModelInfo(const std::string& input_path) { diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..f86181d87 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -4,12 +4,46 @@ #include "MILBuilder.hpp" #include "MILBlob/Fp16.hpp" #include +#include // Include generated protobuf headers #include "MIL.pb.h" namespace katagocoreml { +namespace { +// RAII: set a dtype slot to FLOAT32 for the current scope and restore it on exit. Used to emit a +// sub-region of ops in FP32 inside an otherwise-FP16 model. +struct ScopedFp32 { + CoreML::Specification::MILSpec::DataType& slot; + CoreML::Specification::MILSpec::DataType saved; + explicit ScopedFp32(CoreML::Specification::MILSpec::DataType& s) + : slot(s), saved(s) { s = CoreML::Specification::MILSpec::DataType::FLOAT32; } + ~ScopedFp32() { slot = saved; } + ScopedFp32(const ScopedFp32&) = delete; + ScopedFp32& operator=(const ScopedFp32&) = delete; +}; + +// True if any block in this list is a transformer (attention/FFN), recursing into nested-bottleneck +// blocks (which can themselves contain transformer blocks). Used to scope the off-ANE FP32 +// escalations to transformer trunks only. +bool blocksContainTransformer(const std::vector& blocks) { + for (const auto& entry : blocks) { + if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND || + entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + return true; + } + if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { + const auto& nbt = std::get(*entry.block); + if (blocksContainTransformer(nbt.blocks)) { + return true; + } + } + } + return false; +} +} // namespace + MILBuilder::MILBuilder(const KataGoModelDesc& model, int board_x_size, int board_y_size, @@ -30,7 +64,30 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, ? CoreML::Specification::MILSpec::DataType::FLOAT16 : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) - , m_var_counter(0) {} + , m_var_counter(0) { + // Precision in FP16 mode. The ANE accumulates FP16 in FP16, so any FP32 op runs OFF the FP16-only + // ANE (on CPU/GPU), breaking the ANE pipeline. These off-ANE FP32 escalations are applied ONLY to + // transformer trunks, whose attention blocks widen the activation range enough to overflow FP16 + // accumulation. Plain convnets stay PURE FP16 on the ANE -- the long-standing pre-tier path, verified + // to pass testgpuerror (b18c384nbt, b28c512nbt) and ~2.6x faster than forcing their per-block global + // pooling and convs to FP32 (measured: the per-block pooling round-trips, not the convs, dominate the + // slowdown). For transformers: + // - NARROW trunks (<256ch) build FULLY in FP32: their policy/value metrics sit right on the + // testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial FP32 leaves a + // noisy FP16 spatial stream). Off-ANE but cheap since narrow; equals the FP32 reference. Weights + // stored FP32 (per-weight serialization). + // - WIDER trunks use partial FP32: non-spatial (matmuls + pooling) always, convs only for >=320ch. + const int trunkChannels = model.trunk.trunk_num_channels; + const bool hasTransformer = blocksContainTransformer(model.trunk.blocks); + const bool full_fp32 = use_fp16 && hasTransformer && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + if (full_fp32) { + m_use_fp16 = false; + m_use_fp16_io = false; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + m_nonspatial_fp32 = m_use_fp16 && hasTransformer; + m_conv_fp32 = m_use_fp16 && hasTransformer && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; +} void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { auto* dim = tensor_type->add_dimensions(); @@ -212,9 +269,30 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage - m_ops.registerWeight(name, data, shape); + // Register weight for blob storage (non-owning view into the model). Mark FP32 storage when this + // const is declared FP32 (e.g. inside an FP32 sub-region of an otherwise-FP16 model) so storage + // matches the declared type. + m_ops.registerWeight(name, data, shape, + m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); + emitConstOp(block, name, shape); +} +void MILBuilder::addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) { + // Register derived/owned weight. Mirror addConstOp's per-weight FP32 marking: emitConstOp + // declares this const's dtype as m_weight_dtype, so the stored bytes must follow the same flag + // or BNNS rejects the model ("Metadata data type does not match requested type") when a derived + // const lands in an FP32 sub-region of an FP16 model. + const bool is_fp32 = (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); + m_ops.registerOwnedWeight(name, std::move(data), shape, is_fp32); + emitConstOp(block, name, shape); +} + +void MILBuilder::emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape) { // Add const operation auto* op = block->add_operations(); op->set_type("const"); @@ -328,7 +406,11 @@ void MILBuilder::addFloatScalarConstOp(CoreML::Specification::MILSpec::Block* bl val_type->set_datatype(m_weight_dtype); val_type->set_rank(0); - if (m_use_fp16) { + // Key the storage format off the DECLARED dtype (m_weight_dtype), not the global m_use_fp16: + // a temporarily-flipped FP32 sub-region (m_weight_dtype=FLOAT32 while m_use_fp16 stays true) + // must store FP32 floats, or CoreML rejects the model ("storage and type have different number + // of elements"). For all non-flipped calls m_weight_dtype tracks m_use_fp16, so this is a no-op. + if (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT16) { // For FP16, use bytes storage with FP16 representation MILBlob::Fp16 fp16_val = MILBlob::Fp16::FromFloat(value); std::string bytes_data(reinterpret_cast(&fp16_val.bytes), sizeof(fp16_val.bytes)); @@ -426,6 +508,68 @@ void MILBuilder::addCastOp(CoreML::Specification::MILSpec::Block* block, } } +std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims) { + std::string out = genVarName(input + "_cast"); + std::string dtName = out + "_dt"; + { + auto* op = block->add_operations(); + op->set_type("const"); + auto& na = (*op->mutable_attributes())["name"]; + na.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + na.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtName); + auto& va = (*op->mutable_attributes())["val"]; + va.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + va.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtype); + auto* o = op->add_outputs(); + o->set_name(dtName); + o->mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + } + auto* op = block->add_operations(); + op->set_type("cast"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(input); + (*op->mutable_inputs())["dtype"].add_arguments()->set_name(dtName); + auto* o = op->add_outputs(); + o->set_name(out); + auto* tt = o->mutable_type()->mutable_tensortype(); + tt->set_datatype(dtype == "fp32" ? CoreML::Specification::MILSpec::DataType::FLOAT32 + : CoreML::Specification::MILSpec::DataType::FLOAT16); + tt->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + if (d < 0) tt->add_dimensions()->mutable_unknown()->set_variadic(false); + else tt->add_dimensions()->mutable_constant()->set_size(d); + } + return out; +} + +void MILBuilder::addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant) { + auto pool = [&](const std::string& in, const std::string& msk, const std::string& out) { + if (valueVariant) addGlobalPoolingValueOps(block, in, msk, channels, out); + else addGlobalPoolingOps(block, in, msk, channels, out); + }; + // Non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum over H*W loses too much + // precision at larger board sizes). No addConstOp in the pooling, so flipping m_weight_dtype is + // safe. Cast input/mask up, pool in FP32, cast the [N, channels*3] features back to FP16. + if (!m_nonspatial_fp32) { + pool(input, mask, output); + return; + } + std::string in32 = castFixed(block, input, "fp32", {-1, channels, m_board_y_size, m_board_x_size}); + std::string mask32 = m_optimize_identity_mask + ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + std::string out32 = genVarName(output + "_f32"); + { ScopedFp32 g(m_weight_dtype); pool(in32, mask32, out32); } + addCastOp(block, out32, output, "fp16", {-1, channels * 3}); +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -566,6 +710,21 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, tt->add_dimensions()->mutable_constant()->set_size(4); } + // Channel-gated FP32 convs. The ANE accumulates FP16 convs in FP16, which loses too much + // precision for WIDE trunks and fails testgpuerror at large board sizes (validated: 384ch + // fails, <=256ch is fine FP16-on-ANE). For wide trunks (>= threshold) run convs in FP32 (weights + // cast up at runtime, stored fp16). FP32 convs can't run on the fp16-only ANE, so only the wide + // models that actually need it pay that off-ANE cost; narrow models keep convs on the ANE. + const bool convFp32 = m_conv_fp32; + std::string convX = input, convW = weight_name, convOut = output; + auto savedConvDtype = m_weight_dtype; + if (convFp32) { + convX = castFixed(block, input, "fp32", {-1, layer.in_channels, m_board_y_size, m_board_x_size}); + convW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + convOut = output + "_cf32"; + } + // Add conv operation referencing all const parameters auto* op = block->add_operations(); op->set_type("conv"); @@ -577,12 +736,12 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, inputs["pad"].add_arguments()->set_name(pad_name); inputs["pad_type"].add_arguments()->set_name(pad_type_name); inputs["strides"].add_arguments()->set_name(strides_name); - inputs["weight"].add_arguments()->set_name(weight_name); - inputs["x"].add_arguments()->set_name(input); + inputs["weight"].add_arguments()->set_name(convW); + inputs["x"].add_arguments()->set_name(convX); // Output with dimensions [batch, out_channels, height, width] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(convOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(4); @@ -590,6 +749,11 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + + if (convFp32) { + m_weight_dtype = savedConvDtype; + addCastOp(block, convOut, output, "fp16", {-1, layer.out_channels, m_board_y_size, m_board_x_size}); + } } // Helper: Set output tensor type with 4D shape [batch, C, H, W] @@ -732,6 +896,63 @@ void MILBuilder::addBatchNormActivationOps(CoreML::Specification::MILSpec::Block setTensorOutput4D(op, output, bn.num_channels, m_board_y_size, m_board_x_size); } else if (act.activation_type == ActivationType::Mish) { addMishOps(block, bn_output, output, 4, bn.num_channels); + } else if (act.activation_type == ActivationType::Silu) { + addSiluOps(block, bn_output, output, 4, bn.num_channels); + } +} + +void MILBuilder::addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels) { + // SiLU / Swish: x * sigmoid(x) + auto setOutputType = [this, rank, channels](CoreML::Specification::MILSpec::Operation* op, const std::string& name) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* out_type = out->mutable_type()->mutable_tensortype(); + out_type->set_datatype(m_weight_dtype); + out_type->set_rank(rank); + setBatchDimension(out_type); + out_type->add_dimensions()->mutable_constant()->set_size(channels); + if (rank == 4) { + out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); + out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + } + }; + + std::string sig = output + "_sigmoid"; + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + setOutputType(op, sig); + } + { + auto* op = block->add_operations(); + op->set_type("mul"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + inputs["y"].add_arguments()->set_name(sig); + setOutputType(op, output); + } +} + +void MILBuilder::setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* t = out->mutable_type()->mutable_tensortype(); + t->set_datatype(m_weight_dtype); + t->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + auto* dim = t->add_dimensions(); + if (d < 0) + dim->mutable_unknown()->set_variadic(false); + else + dim->mutable_constant()->set_size(d); } } @@ -887,23 +1108,38 @@ void MILBuilder::addMatMulOp(CoreML::Specification::MILSpec::Block* block, CoreML::Specification::MILSpec::DataType::BOOL); } + // Non-spatial matmul in FP32 (KataGo FP16 convention; weights cast up at runtime, stored fp16). + std::string mmIn = input, mmW = weight_name, mmOut = output; + auto savedMmDtype = m_weight_dtype; + if (m_nonspatial_fp32) { + mmIn = castFixed(block, input, "fp32", {-1, layer.in_channels}); + mmW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + mmOut = output + "_mmf32"; + } + // Add matmul operation auto* op = block->add_operations(); op->set_type("matmul"); auto& inputs = *op->mutable_inputs(); inputs["transpose_x"].add_arguments()->set_name(transpose_x_name); inputs["transpose_y"].add_arguments()->set_name(transpose_y_name); - inputs["x"].add_arguments()->set_name(input); - inputs["y"].add_arguments()->set_name(weight_name); + inputs["x"].add_arguments()->set_name(mmIn); + inputs["y"].add_arguments()->set_name(mmW); // Output with 2D shape [batch, out_channels] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(mmOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(2); setBatchDimension(out_type); out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); + + if (m_nonspatial_fp32) { + m_weight_dtype = savedMmDtype; + addCastOp(block, mmOut, output, "fp16", {-1, layer.out_channels}); + } } void MILBuilder::addMatBiasOp(CoreML::Specification::MILSpec::Block* block, @@ -958,13 +1194,15 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, // Add transposed weight constant with shape [out_channels, in_channels] std::vector transposed_shape = {static_cast(out_ch), static_cast(in_ch)}; - addConstOp(block, weight_name, transposed_weights, transposed_shape); + addOwnedConstOp(block, weight_name, std::move(transposed_weights), transposed_shape); // Add bias constant std::vector bias_shape = {static_cast(bias.num_channels)}; addConstOp(block, bias_name, bias.weights, bias_shape); - // Add linear operation + // NOTE: the MIL `linear` op requires const weight/bias, so the runtime-cast-to-FP32 trick can't + // be applied here (unlike `matmul`). Value-head linear stays FP16; if a model ever needs it in + // FP32, rewrite as matmul+add (matmul accepts cast inputs). auto* op = block->add_operations(); op->set_type("linear"); auto& inputs = *op->mutable_inputs(); @@ -1637,6 +1875,641 @@ void MILBuilder::addGlobalPoolingValueOps(CoreML::Specification::MILSpec::Block* // Network Component Builders // ============================================================================ +// --------------------------------------------------------------------------- +// Transformer blocks (MIL). Layout is NCHW [B, C, H, W]; spatial positions +// (H*W, ordered y*W+x) are treated as the attention sequence. RoPE is applied +// via a fixed pair-rotation matmul plus host-precomputed cos/sin tables, which +// keeps every tensor rank <= 4 (ANE-friendly). +// --------------------------------------------------------------------------- + +std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + // RMSNorm reduction core: square -> mean over channels -> rsqrt. In FP16 mode compute this + // core in FP32 (cast input up, flip the working dtype so the core's op outputs + eps scalar are + // FP32, then cast 1/rms back down). The FP16 channel reduction loses too much precision on the + // ANE; only this core is FP32 - the scaling/weight/mask below stay FP16. No addConstOp lives in + // the flipped window, so weight serialization is unaffected. + auto savedDtype = m_weight_dtype; + std::string sqSrc = input; + if (m_use_fp16) { + sqSrc = genVarName(prefix + "_in32"); + addCastOp(block, input, sqSrc, "fp32", {-1, C, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", sqSrc, sqSrc, sq, {-1, C, H, W}); + // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so + // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can + // overflow FP16 (and the FP16 accumulation on ANE) for large activations. + std::string meanSq = genVarName(prefix + "_meansq"); + { + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + } + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, {-1, 1, H, W}); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", {-1, 1, H, W}); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string weightName = prefix + "_weight"; + addConstOp(block, weightName, desc.weight, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, weightName, scaled, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + emit2("mul", scaled, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + auto reduceSum = [&](const std::string& x, const std::string& out, const std::vector& axes, + const std::vector& dims) { + std::string axesName = out + "_axes"; + std::string keepName = out + "_keep"; + addIntArrayConstOp(block, axesName, axes); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_sum"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, out, dims); + }; + + // Variance core (mask -> square -> reduce -> rsqrt) in FP32 when in FP16 mode. The trunk-tip + // norm in particular reduces over many elements and loses too much precision in FP16 on the + // ANE; compute the core in FP32 and cast 1/rms back to FP16. Only the core is FP32 - gamma/beta, + // the activation and the final mask below stay FP16. No addConstOp lives in the flipped window. + auto savedDtype = m_weight_dtype; + std::string tinput = input; + std::string tmask = mask; + if (m_use_fp16) { + tinput = genVarName(prefix + "_in32"); + addCastOp(block, input, tinput, "fp32", {-1, C, H, W}); + tmask = genVarName(prefix + "_mask32"); + addCastOp(block, mask, tmask, "fp32", {-1, 1, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string masked = genVarName(prefix + "_premask"); + emit2("mul", tinput, tmask, masked, {-1, C, H, W}); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", masked, masked, sq, {-1, C, H, W}); + + std::string meanSq; + std::vector denomDims; + if (desc.spatial) { + // Mean of squares over valid positions and channels. A reduce_sum over C*H*W elements + // overflows FP16 (e.g. trunk tip with large activations on ANE -> inf -> rsqrt 0 -> + // collapse). Instead take reduce_mean over all of C,H,W (masked positions are zero) and + // rescale by totalPositions/validCount to restrict the mean to valid positions. + std::string meanAll = genVarName(prefix + "_meanall"); + { + std::string axesName = meanAll + "_axes", keepName = meanAll + "_keep"; + addIntArrayConstOp(block, axesName, {1, 2, 3}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanAll, {-1, 1, 1, 1}); + } + std::string count = genVarName(prefix + "_count"); + reduceSum(tmask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + std::string totalPosName = prefix + "_totalpos"; + addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); + std::string scaleF = genVarName(prefix + "_scalef"); + emit2("real_div", totalPosName, count, scaleF, {-1, 1, 1, 1}); // totalPos / validCount + meanSq = genVarName(prefix + "_meansq"); + emit2("mul", meanAll, scaleF, meanSq, {-1, 1, 1, 1}); + denomDims = {-1, 1, 1, 1}; + } else { + meanSq = genVarName(prefix + "_meansq"); + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + denomDims = {-1, 1, H, W}; + } + + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, denomDims); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", denomDims); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string gammaName = prefix + "_gamma"; + std::string betaName = prefix + "_beta"; + addConstOp(block, gammaName, desc.gamma, {1, static_cast(C), 1, 1}); + addConstOp(block, betaName, desc.beta, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, gammaName, scaled, {-1, C, H, W}); + std::string biased = genVarName(prefix + "_biased"); + emit2("add", scaled, betaName, biased, {-1, C, H, W}); + + std::string activated; + if (act.activation_type == ActivationType::Silu) { + activated = genVarName(prefix + "_act"); + addSiluOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::Mish) { + activated = genVarName(prefix + "_act"); + addMishOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::ReLU) { + activated = genVarName(prefix + "_act"); + auto* op = block->add_operations(); + op->set_type("relu"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(biased); + setShape(op, activated, {-1, C, H, W}); + } else { + activated = biased; + } + std::string out = genVarName(prefix + "_out"); + emit2("mul", activated, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.q_proj.in_channels; + const int H = m_board_y_size, W = m_board_x_size; + const int seq = H * W; + const int numHeads = desc.num_heads, numKVHeads = desc.num_kv_heads; + const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; + const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims, bool transX, bool transY) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, transX); + addBoolScalarConstOp(block, tyName, transY); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + // Q/K/V projection matmuls in FP32 (non-spatial, per KataGo's FP16 convention): they reduce over + // C channels and the ANE's FP16 accumulation loses too much precision for wide models. Weights + // stay fp16-stored (cast up at runtime); output cast back to FP16 for the FP16 head reshapes. + auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { + std::string wName = nm + "_w"; + addConstOp(block, wName, w.weights, w.getWeightShape()); + std::string out = genVarName(nm); + if (m_nonspatial_fp32) { + std::string x32 = castFixed(block, x2d, "fp32", {-1, C}); + std::string w32 = castFixed(block, wName, "fp32", w.getWeightShape()); + auto sd = m_weight_dtype; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string o32 = genVarName(nm + "_f32"); + matmul(x32, w32, o32, {-1, total}, false, false); + m_weight_dtype = sd; + out = castFixed(block, o32, "fp16", {-1, total}); + } else { + matmul(x2d, wName, out, {-1, total}, false, false); + } + return out; + }; + std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); + std::string k2d = proj(desc.k_proj, prefix + "_k", kTotal); + std::string v2d = proj(desc.v_proj, prefix + "_v", vTotal); + auto toHeads = [&](const std::string& in2d, const std::string& nm, int nh, int hd) { + std::string r = genVarName(nm + "_r"); + reshape(in2d, r, {-1, seq, nh, hd}, {-1, seq, nh, hd}); + std::string t = genVarName(nm + "_t"); + transpose(r, t, {0, 2, 1, 3}, {-1, nh, seq, hd}); + return t; + }; + std::string qh = toHeads(q2d, prefix + "_qh", numHeads, qHeadDim); + std::string kh = toHeads(k2d, prefix + "_kh", numKVHeads, qHeadDim); + std::string vh = toHeads(v2d, prefix + "_vh", numKVHeads, vHeadDim); + + if (desc.use_rope) { + const int numPairs = qHeadDim / 2; + const int numPairsPerDim = numPairs / 2; + const int dimHalf = qHeadDim / 2; + auto applyRope = [&](const std::string& x, int nh, const std::string& tag) { + std::vector cosFull(static_cast(nh) * seq * qHeadDim, 0.0f); + std::vector sinFull(static_cast(nh) * seq * qHeadDim, 0.0f); + for (int h = 0; h < nh; h++) { + int kvh = (h * numKVHeads) / nh; + for (int xy = 0; xy < seq; xy++) { + int y = xy / W; + int x = xy % W; + for (int p = 0; p < numPairs; p++) { + float angle = 0.0f; + if (desc.learnable_rope) { + float fx = desc.rope_freqs[(kvh * numPairs + p) * 2 + 0]; + float fy = desc.rope_freqs[(kvh * numPairs + p) * 2 + 1]; + angle = static_cast(x) * fx + static_cast(y) * fy; + } else { + if (p < numPairsPerDim) { + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * p) / dimHalf); + angle = static_cast(y) * freq; + } else { + int pAdj = p - numPairsPerDim; + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * pAdj) / dimHalf); + angle = static_cast(x) * freq; + } + } + float c = std::cos(angle), s = std::sin(angle); + size_t base = (static_cast(h) * seq + xy) * qHeadDim + 2 * p; + cosFull[base] = c; cosFull[base + 1] = c; + sinFull[base] = s; sinFull[base + 1] = s; + } + } + } + std::vector R(static_cast(qHeadDim) * qHeadDim, 0.0f); + for (int p = 0; p < numPairs; p++) { + R[(2 * p) * qHeadDim + (2 * p + 1)] = 1.0f; + R[(2 * p + 1) * qHeadDim + (2 * p)] = -1.0f; + } + std::string cosName = prefix + "_" + tag + "_cos"; + std::string sinName = prefix + "_" + tag + "_sin"; + std::string rName = prefix + "_" + tag + "_R"; + // cosFull/sinFull/R are locals computed here, so register them as OWNED consts: the + // WeightEntry holds a non-owning FloatView and serialization runs after this lambda + // returns, so a non-owning addConstOp would dangle. + addOwnedConstOp(block, cosName, std::move(cosFull), {1, nh, seq, qHeadDim}); + addOwnedConstOp(block, sinName, std::move(sinFull), {1, nh, seq, qHeadDim}); + // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. + addOwnedConstOp(block, rName, std::move(R), {1, 1, qHeadDim, qHeadDim}); + std::string rotated = genVarName(prefix + "_" + tag + "_rot"); + matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); + std::string xc = genVarName(prefix + "_" + tag + "_xc"); + binary("mul", x, cosName, xc, {-1, nh, seq, qHeadDim}); + std::string rs = genVarName(prefix + "_" + tag + "_rs"); + binary("mul", rotated, sinName, rs, {-1, nh, seq, qHeadDim}); + std::string out = genVarName(prefix + "_" + tag + "_rope"); + binary("add", xc, rs, out, {-1, nh, seq, qHeadDim}); + return out; + }; + qh = applyRope(qh, numHeads, "q"); + kh = applyRope(kh, numKVHeads, "k"); + } + + // GQA: when numKVHeads < numHeads, repeat each KV head groupSize times along the head + // axis (axis 1) so query head h consumes kv head (h / groupSize). RoPE has already been + // applied above to the unexpanded kh (kh = applyRope(kh, numKVHeads, "k")), mirroring the + // GPU path (metallayers.swift repeatKVHeads runs AFTER applyRope). We slice each KV head + // and concat its copies consecutively, so the resulting head index is kv*groupSize + g; + // query head h then maps to kv = h/groupSize == (h*numKVHeads)/numHeads (exact divisor, + // the same formula the qh RoPE table uses) == Eigen's kvh = h/kvGroupSize. slice_by_size + + // concat (not reshape+broadcast) avoids the dynamic -1 batch broadcast pitfall, same as the + // GPU code. The repeat is required so the scores (qh@kh^T) and attnOut (attn@vh) matmuls see + // matching [B,numHeads,...] batch dims instead of numHeads vs numKVHeads (no broadcast). + if (numKVHeads != numHeads) { + const int groupSize = numHeads / numKVHeads; + auto repeatKVHeads = [&](const std::string& x, const std::string& tag, int headDim) { + std::vector parts; + parts.reserve(static_cast(numKVHeads) * groupSize); + for (int kv = 0; kv < numKVHeads; kv++) { + for (int g = 0; g < groupSize; g++) { + std::string part = genVarName(prefix + "_" + tag + "_slc"); + std::string beginName = part + "_begin", sizeName = part + "_size"; + addIntArrayConstOp(block, beginName, {0, kv, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, headDim}); + auto* sop = block->add_operations(); + sop->set_type("slice_by_size"); + (*sop->mutable_inputs())["x"].add_arguments()->set_name(x); + (*sop->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*sop->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(sop, part, {-1, 1, seq, headDim}); + parts.push_back(part); + } + } + std::string out = genVarName(prefix + "_" + tag + "_exp"); + std::string axisName = out + "_axis", interleaveName = out + "_interleave"; + addIntScalarConstOp(block, axisName, 1); + addBoolScalarConstOp(block, interleaveName, false); + auto* cop = block->add_operations(); + cop->set_type("concat"); + auto& cin = *cop->mutable_inputs(); + for (const std::string& part : parts) + cin["values"].add_arguments()->set_name(part); + cin["axis"].add_arguments()->set_name(axisName); + cin["interleave"].add_arguments()->set_name(interleaveName); + setShape(cop, out, {-1, numHeads, seq, headDim}); + return out; + }; + kh = repeatKVHeads(kh, "khrep", qHeadDim); + vh = repeatKVHeads(vh, "vhrep", vHeadDim); + } + + std::string scores = genVarName(prefix + "_scores"); + matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); + std::string scaleName = prefix + "_scale"; + addFloatScalarConstOp(block, scaleName, 1.0f / std::sqrt(static_cast(qHeadDim))); + std::string scaled = genVarName(prefix + "_sc"); + binary("mul", scores, scaleName, scaled, {-1, numHeads, seq, seq}); + + // mask [B,1,H,W] -> [B,1,1,seq] directly (contiguous reshape; H,W already trailing so the + // row-major flatten gives seq index xy=y*W+x). No transpose -> avoids the reshape-after- + // transpose issue, and is also correct for non-full boards. + std::string maskSeq = genVarName(prefix + "_mseq"); + reshape(mask, maskSeq, {-1, 1, 1, seq}, {-1, 1, 1, seq}); + std::string oneName = prefix + "_one"; + addFloatScalarConstOp(block, oneName, 1.0f); + std::string mm1 = genVarName(prefix + "_mm1"); + binary("sub", maskSeq, oneName, mm1, {-1, 1, 1, seq}); + // Use an FP16-safe magnitude: 1e9 overflows FP16 to +inf, and for valid keys + // (maskSeq-1 == 0) the product 0 * inf becomes NaN, poisoning the whole softmax. + // 1e4 is well within FP16 range and exp(score - 1e4) still underflows to 0. + std::string bigName = prefix + "_big"; + addFloatScalarConstOp(block, bigName, 1.0e4f); + std::string keyBias = genVarName(prefix + "_kb"); + binary("mul", mm1, bigName, keyBias, {-1, 1, 1, seq}); + std::string scoresMasked = genVarName(prefix + "_scm"); + binary("add", scaled, keyBias, scoresMasked, {-1, numHeads, seq, seq}); + + std::string attn = genVarName(prefix + "_attn"); + { + std::string axisName = attn + "_axis"; + addIntScalarConstOp(block, axisName, 3); + auto* op = block->add_operations(); + op->set_type("softmax"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(scoresMasked); + (*op->mutable_inputs())["axis"].add_arguments()->set_name(axisName); + setShape(op, attn, {-1, numHeads, seq, seq}); + } + + std::string attnOut = genVarName(prefix + "_ao"); + matmul(attn, vh, attnOut, {-1, numHeads, seq, vHeadDim}, false, false); + + // Output projection, done per-head to avoid reshape-after-transpose: CoreML's reshape + // ignores an immediately-preceding transpose, so merging [head,dim]->channels after a + // transpose scrambles the data. Instead slice each head from attnOut (head is the + // contiguous axis 1), reshape (leading-merge only), matmul its weight slice, and sum. + // out[b,s,c] = sum_h sum_d attnOut[b,h,s,d] * outProj.weights[(h*vHeadDim+d)*outC + c] + const int outC = desc.out_proj.out_channels; + std::string proj2d; + for (int h = 0; h < numHeads; h++) { + std::string aoh = genVarName(prefix + "_aoh"); + { + std::string beginName = aoh + "_begin", sizeName = aoh + "_size"; + addIntArrayConstOp(block, beginName, {0, h, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, vHeadDim}); + auto* op = block->add_operations(); + op->set_type("slice_by_size"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(attnOut); + (*op->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*op->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(op, aoh, {-1, 1, seq, vHeadDim}); + } + std::string aoh2d = genVarName(prefix + "_aoh2d"); + reshape(aoh, aoh2d, {-1, vHeadDim}, {-1, vHeadDim}); // [B*seq, vHeadDim] + std::string wh = prefix + "_ow" + std::to_string(h); + std::vector whData(static_cast(vHeadDim) * outC); + for (int d = 0; d < vHeadDim; d++) + for (int c = 0; c < outC; c++) + whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; + // whData is a per-head local slice; register OWNED so its FloatView stays valid until + // serialization (a non-owning addConstOp would dangle after this loop iteration). + addOwnedConstOp(block, wh, std::move(whData), {vHeadDim, outC}); + std::string contrib = genVarName(prefix + "_contrib"); + matmul(aoh2d, wh, contrib, {-1, outC}, false, false); + if (h == 0) { + proj2d = contrib; + } else { + std::string acc = genVarName(prefix + "_acc"); + binary("add", proj2d, contrib, acc, {-1, outC}); + proj2d = acc; + } + } + std::string projNHWC = genVarName(prefix + "_pnhwc"); + reshape(proj2d, projNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string projNCHW = genVarName(prefix + "_pnchw"); + transpose(projNHWC, projNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", projNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int ffn = desc.ffn_channels; + const int H = m_board_y_size, W = m_board_x_size; + + if (!desc.use_swiglu) { + throw std::runtime_error(desc.name + ": non-SwiGLU transformer FFN not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, false); + addBoolScalarConstOp(block, tyName, false); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + + // FFN matmuls in FP32 (weights cast up at runtime, stored fp16) — KataGo's FP16 convention is + // spatial(convs)=FP16, non-spatial(matmuls)=FP32 (see openclbackend.cpp). The ANE accumulates + // FP16 matmuls in FP16, which loses too much precision over C/ffn; run them in FP32 instead. + std::string w1 = prefix + "_w1"; + addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); + std::string wg = prefix + "_wg"; + addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + + auto savedDtype = m_weight_dtype; + std::string mx2d = x2d, mw1 = w1, mwg = wg, mw2 = w2; + if (m_nonspatial_fp32) { + mx2d = castFixed(block, x2d, "fp32", {-1, C}); + mw1 = castFixed(block, w1, "fp32", desc.linear1.getWeightShape()); + mwg = castFixed(block, wg, "fp32", desc.linear_gate.getWeightShape()); + mw2 = castFixed(block, w2, "fp32", desc.linear2.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string a = genVarName(prefix + "_a"); + matmul(mx2d, mw1, a, {-1, ffn}); + std::string g = genVarName(prefix + "_g"); + matmul(mx2d, mwg, g, {-1, ffn}); + + std::string sig = genVarName(prefix + "_sig"); + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(a); + setShape(op, sig, {-1, ffn}); + } + std::string siluA = genVarName(prefix + "_silu"); + binary("mul", a, sig, siluA, {-1, ffn}); + std::string h = genVarName(prefix + "_h"); + binary("mul", siluA, g, h, {-1, ffn}); + + std::string oCore = genVarName(prefix + "_o"); + matmul(h, mw2, oCore, {-1, C}); + std::string o = oCore; + if (m_nonspatial_fp32) { + m_weight_dtype = savedDtype; + o = castFixed(block, oCore, "fp16", {-1, C}); + } + + std::string oNHWC = genVarName(prefix + "_onhwc"); + reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string oNCHW = genVarName(prefix + "_onchw"); + transpose(oNHWC, oNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", oNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, const std::string& spatial_input, const std::string& global_input, @@ -1747,12 +2620,23 @@ std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, } else if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { const auto& block_desc = std::get(*entry.block); x = buildNestedBottleneckBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, block_desc, mask, prefix); } } // Trunk tip - std::string trunk_out = genVarName("trunk_tip"); - addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + std::string trunk_out; + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk_out = genVarName("trunk_tip"); + addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + } else { + trunk_out = addTrunkRMSNorm(block, x, trunk.trunk_tip_rms_norm, trunk.trunk_tip_activation, mask, "trunk_tip_rms"); + } return trunk_out; } @@ -1814,9 +2698,11 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling + // Global pooling (FP32 when m_nonspatial_fp32 -- see addGlobalPoolingFp32). Feeds a bias back + // into the whole trunk, so the FP16 spatial sum must not lose precision for wide trunks. std::string gpool_features = genVarName(prefix + "_gpool_features"); - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + addGlobalPoolingFp32(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features, + /*valueVariant=*/false); // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -1898,6 +2784,12 @@ std::string MILBuilder::buildNestedBottleneckBlock(CoreML::Specification::MILSpe } else if (entry.block_kind == GLOBAL_POOLING_BLOCK_KIND) { const auto& nested = std::get(*entry.block); x = buildGlobalPoolingResidualBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, nested, mask, nested_prefix); } } @@ -1942,9 +2834,9 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 + // Global pooling on G1 (FP32 when m_nonspatial_fp32; feeds the policy bias / policyKLDiv). std::string g1_pooled = genVarName("policy_g1_pool"); - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + addGlobalPoolingFp32(block, g1, mask, ph.g1_conv.out_channels, g1_pooled, /*valueVariant=*/false); // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2002,6 +2894,8 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, pass_activated, ph.gpool_to_pass_mul.out_channels); } else if (ph.pass_activation->activation_type == ActivationType::Mish) { addMishOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); + } else if (ph.pass_activation->activation_type == ActivationType::Silu) { + addSiluOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); } else { pass_activated = pass_biased; } @@ -2032,9 +2926,9 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) + // Global pooling (value head version; FP32 when m_nonspatial_fp32). std::string v1_pooled = genVarName("value_v1_pool"); - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + addGlobalPoolingFp32(block, v1, mask, vh.v1_conv.out_channels, v1_pooled, /*valueVariant=*/true); // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); @@ -2049,6 +2943,8 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, v2, vh.v2_mul.out_channels); } else if (vh.v2_activation.activation_type == ActivationType::Mish) { addMishOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); + } else if (vh.v2_activation.activation_type == ActivationType::Silu) { + addSiluOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); } else { v2 = v2_bias; } @@ -2085,6 +2981,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act1, encoder.mul1.out_channels); } else if (encoder.act1.activation_type == ActivationType::Mish) { addMishOps(block, bias1, act1, 2, encoder.mul1.out_channels); + } else if (encoder.act1.activation_type == ActivationType::Silu) { + addSiluOps(block, bias1, act1, 2, encoder.mul1.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); @@ -2107,6 +3005,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act2, encoder.mul2.out_channels); } else if (encoder.act2.activation_type == ActivationType::Mish) { addMishOps(block, bias2, act2, 2, encoder.mul2.out_channels); + } else if (encoder.act2.activation_type == ActivationType::Silu) { + addSiluOps(block, bias2, act2, 2, encoder.mul2.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..6897f39a1 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -29,8 +29,8 @@ class MILBuilder { /// @return Unique pointer to MIL Program protobuf std::unique_ptr build(); - /// Get weight entries for blob serialization - const std::vector& getWeights() const { return m_ops.getWeights(); } + /// Get weight entries for blob serialization (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_ops.getWeightsMutable(); } /// Get board dimensions int getBoardXSize() const { return m_board_x_size; } @@ -43,6 +43,16 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; + // FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer + // trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation). + // Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass + // testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32. + // For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls + + // pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; @@ -80,6 +90,24 @@ class MILBuilder { const std::vector& data, const std::vector& shape); + // addConstOp registers a NON-OWNING view into `data` (see WeightEntry), so the + // backing storage must outlive serialization. Binding a temporary here would + // dangle. Deleted so such calls fail to compile; use addOwnedConstOp for + // derived/temporary tensors that KataGoOps should own instead. + void addConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + + void addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape); + + void emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape); + void addIntArrayConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& values); @@ -102,6 +130,23 @@ class MILBuilder { const std::string& dtype, const std::vector& shape); + // Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for + // weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the + // new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim. + std::string castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims); + + // Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool, + // cast the pooled features back to FP16). valueVariant selects the value-head pooling variant. + void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -120,6 +165,44 @@ class MILBuilder { int rank, int channels); + void addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels); + + // Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions. + void setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims); + + // Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out. + std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix); + + // Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out. + std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block, const std::string& input, const std::string& mask, diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..e86364943 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -14,12 +14,30 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { WeightEntry entry; entry.name = name; - entry.data = data; + entry.data = FloatView{data.data(), data.size()}; entry.shape = shape; entry.blob_offset = 0; // Will be set during serialization + entry.is_fp32 = is_fp32; + m_weights.push_back(std::move(entry)); + return name; +} + +std::string KataGoOps::registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape, + bool is_fp32) { + m_owned.push_back(std::move(data)); + const std::vector& stored = m_owned.back(); + WeightEntry entry; + entry.name = name; + entry.data = FloatView{stored.data(), stored.size()}; + entry.shape = shape; + entry.blob_offset = 0; + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..385648d19 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -5,17 +5,33 @@ #include "../types/KataGoTypes.hpp" #include +#include #include #include namespace katagocoreml { -/// Weight entry for blob file storage +/// Minimal non-owning view over a contiguous float buffer. KataGo-local on +/// purpose: keeps the MILBlob dependency out of this header (conversion to +/// MILBlob::Util::Span happens only at the serializer boundary). +struct FloatView { + const float* ptr = nullptr; + size_t len = 0; + const float* data() const { return ptr; } + size_t size() const { return len; } + bool empty() const { return len == 0; } + float operator[](size_t i) const { return ptr[i]; } +}; + +/// Weight entry for blob file storage. `data` is a NON-OWNING view into the live +/// KataGoModelDesc (or into KataGoOps::m_owned for derived tensors). struct WeightEntry { std::string name; - std::vector data; + FloatView data; // non-owning view (replaces raw ptr + count) std::vector shape; uint64_t blob_offset = 0; // Set during serialization + bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an + // FP32 sub-region of an otherwise-FP16 model). Else stored per global mode. }; /// Precomputed constants for identity mask optimization @@ -51,16 +67,33 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight that lives in the model (stored as a non-owning view). + /// is_fp32 marks it for FP32 storage. std::string registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); + + /// The stored WeightEntry is a non-owning view into `data`, so a temporary + /// would leave it dangling. Deleted to reject such calls at compile time; + /// use registerOwnedWeight for tensors KataGoOps should own. + std::string registerWeight(const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + + /// Register a derived/temporary weight; KataGoOps takes ownership so the + /// view stays valid through serialization. is_fp32 marks it for FP32 storage + /// (mirrors registerWeight) so the stored dtype matches the declared const dtype. + std::string registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape, + bool is_fp32 = false); - /// Get all registered weights - const std::vector& getWeights() const { return m_weights; } + /// Get all registered weights (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_weights; } - /// Clear all registered weights - void clearWeights() { m_weights.clear(); } + /// Clear all registered weights (and their owned backing buffers) + void clearWeights() { m_weights.clear(); m_owned.clear(); } /// Generate unique operation name std::string genOpName(const std::string& prefix); @@ -71,6 +104,7 @@ class KataGoOps { bool m_optimize_identity_mask; MaskConstants m_mask_constants; std::vector m_weights; + std::deque> m_owned; int m_op_counter = 0; }; diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 68f1a0e56..b7a662afd 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -30,54 +29,41 @@ bool KataGoParser::isVersionSupported(int version) { } // ============================================================================ -// File Loading +// Stream Primitives // ============================================================================ -void KataGoParser::loadFile() { - // Check if gzip compressed - bool is_gzip = false; - if (m_model_path.size() >= 3) { - std::string ext = m_model_path.substr(m_model_path.size() - 3); - is_gzip = (ext == ".gz"); - } - - if (is_gzip) { - // Read gzipped file - gzFile gz = gzopen(m_model_path.c_str(), "rb"); - if (!gz) { - throw std::runtime_error("Cannot open gzip file: " + m_model_path); - } - - // Read in chunks - m_buffer.clear(); - std::vector chunk(1024 * 1024); // 1MB chunks - int bytes_read; - while ((bytes_read = gzread(gz, chunk.data(), static_cast(chunk.size()))) > 0) { - m_buffer.insert(m_buffer.end(), chunk.begin(), chunk.begin() + bytes_read); - } - - if (bytes_read < 0) { - int errnum; - const char* errmsg = gzerror(gz, &errnum); - gzclose(gz); - throw std::runtime_error("Error reading gzip file: " + std::string(errmsg)); - } - - gzclose(gz); - } else { - // Read regular file - std::ifstream file(m_model_path, std::ios::binary | std::ios::ate); - if (!file) { - throw std::runtime_error("Cannot open file: " + m_model_path); - } +bool KataGoParser::refill() { + if(!m_gz) return false; + int n = gzread(m_gz.get(), m_refill.data(), (unsigned)m_refill.size()); + if(n < 0) { + int errnum; + const char* errmsg = gzerror(m_gz.get(), &errnum); + throw std::runtime_error("Error reading gzip stream: " + std::string(errmsg)); + } + m_refillPos = 0; + m_refillLen = (size_t)n; + return n > 0; +} - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); +int KataGoParser::peekByte() { + if(m_refillPos >= m_refillLen) { + if(!refill()) return -1; + } + return (int)m_refill[m_refillPos]; +} - m_buffer.resize(static_cast(size)); - if (!file.read(reinterpret_cast(m_buffer.data()), size)) { - throw std::runtime_error("Error reading file: " + m_model_path); +void KataGoParser::readExact(uint8_t* dst, size_t n, const std::string& name) { + size_t got = 0; + while(got < n) { + if(m_refillPos >= m_refillLen) { + if(!refill()) + throw std::runtime_error(name + ": unexpected EOF in binary block"); } + size_t avail = m_refillLen - m_refillPos; + size_t take = std::min(avail, n - got); + std::memcpy(dst + got, m_refill.data() + m_refillPos, take); + m_refillPos += take; + got += take; } } @@ -86,15 +72,16 @@ void KataGoParser::loadFile() { // ============================================================================ KataGoModelDesc KataGoParser::parse() { - loadFile(); - m_pos = 0; - - // Detect if binary format (check for @BIN@ marker) - const std::string bin_marker = "@BIN@"; - auto it = std::search(m_buffer.begin(), m_buffer.end(), - bin_marker.begin(), bin_marker.end()); - m_binary_floats = (it != m_buffer.end()); - + // Allocate the refill buffer first; if this throws, no handle has been opened. + m_refill.resize(1024 * 1024); + m_gz.reset(gzopen(m_model_path.c_str(), "rb")); + if(!m_gz) + throw std::runtime_error("Cannot open file: " + m_model_path); + m_refillPos = 0; + m_refillLen = 0; + m_formatDetected = false; // decided at first readFloats + m_binary_floats = true; + // ~GzHandle closes the file on normal return OR exception — no try/catch needed. return parseModel(); } @@ -103,24 +90,20 @@ KataGoModelDesc KataGoParser::parse() { // ============================================================================ void KataGoParser::skipWhitespace() { - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c != ' ' && c != '\t' && c != '\n' && c != '\r') { - break; - } - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c != ' ' && c != '\t' && c != '\n' && c != '\r') break; + m_refillPos++; } } void KataGoParser::readUntilWhitespace(std::string& out) { out.clear(); - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { - break; - } - out += c; - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c == ' ' || c == '\t' || c == '\n' || c == '\r') break; + out += (char)c; + m_refillPos++; } } @@ -147,37 +130,28 @@ bool KataGoParser::readBool() { std::vector KataGoParser::readFloats(size_t count, const std::string& name) { std::vector floats(count); + skipWhitespace(); - if (!m_binary_floats) { + // KataGo model files are uniformly text OR uniformly binary, so detecting the + // format once at the first weight block (binary blocks start with '@BIN@') + // is valid for all subsequent blocks. + if(!m_formatDetected) { + m_binary_floats = (peekByte() == '@'); + m_formatDetected = true; + } + + if(!m_binary_floats) { // Text format - for (size_t i = 0; i < count; i++) { + for(size_t i = 0; i < count; i++) floats[i] = readFloat(); - } } else { - // Binary format - find @BIN@ marker - while (m_pos < m_buffer.size()) { - if (m_buffer[m_pos] == '@') { - break; - } - m_pos++; - } - - // Check for @BIN@ header - if (m_pos + 5 > m_buffer.size() || - std::memcmp(&m_buffer[m_pos], "@BIN@", 5) != 0) { + // Binary: consume the "@BIN@" marker, then read count*4 raw bytes. + char marker[5]; + readExact(reinterpret_cast(marker), 5, name); + if(std::memcmp(marker, "@BIN@", 5) != 0) throw std::runtime_error(name + ": expected @BIN@ marker for binary float block"); - } - m_pos += 5; - - // Read binary floats (little-endian) - size_t num_bytes = count * 4; - if (m_pos + num_bytes > m_buffer.size()) { - throw std::runtime_error(name + ": not enough bytes for " + std::to_string(count) + " floats"); - } - // Copy as little-endian float32 - std::memcpy(floats.data(), &m_buffer[m_pos], num_bytes); - m_pos += num_bytes; + readExact(reinterpret_cast(floats.data()), count * 4, name); } // Reject NaN/Inf weights: corrupted or otherwise invalid models would @@ -315,6 +289,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) { layer.activation_type = ActivationType::ReLU; } else if (activation_str == "ACTIVATION_MISH") { layer.activation_type = ActivationType::Mish; + } else if (activation_str == "ACTIVATION_SILU") { + layer.activation_type = ActivationType::Silu; } else { throw std::runtime_error("Unknown activation type: " + activation_str); } @@ -420,6 +396,98 @@ static void checkBlockChannels(const std::string& block_name, const std::string& } } +TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() { + TransformerRMSNormDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": transformer rmsnorm numChannels must be >= 1"); + } + layer.weight = readFloats(layer.num_channels, layer.name + "/weight"); + return layer; +} + +RMSNormLayerDesc KataGoParser::parseRMSNormLayer() { + RMSNormLayerDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + layer.spatial = (readInt() != 0); + layer.cgroup_size = readInt(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": rmsnorm numChannels must be >= 1"); + } + if (layer.cgroup_size != 0) { + throw std::runtime_error(layer.name + ": grouped spatial RMSNorm is not supported"); + } + layer.gamma = readFloats(layer.num_channels, layer.name + "/gamma"); + layer.beta = readFloats(layer.num_channels, layer.name + "/beta"); + return layer; +} + +TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version) { + TransformerAttentionBlockDesc block; + block.name = readString(); + block.num_heads = readInt(); + block.num_kv_heads = readInt(); + block.q_head_dim = readInt(); + block.v_head_dim = readInt(); + block.use_rope = (readInt() != 0); + block.learnable_rope = (readInt() != 0); + + if (block.num_heads < 1 || block.num_kv_heads < 1 || (block.num_heads % block.num_kv_heads != 0)) { + throw std::runtime_error(block.name + ": invalid numHeads/numKVHeads"); + } + if (block.use_rope && (block.q_head_dim % 2 != 0)) { + throw std::runtime_error(block.name + ": qHeadDim must be even when RoPE is used"); + } + + block.pre_ln = parseTransformerRMSNorm(); + block.q_proj = parseMatMulLayer(); + block.k_proj = parseMatMulLayer(); + block.v_proj = parseMatMulLayer(); + block.out_proj = parseMatMulLayer(); + + if (block.use_rope) { + if (block.learnable_rope) { + readString(); // ropeFreqs name + block.rope_num_kv_heads = readInt(); + block.rope_num_pairs = readInt(); + int rope_dim2 = readInt(); + if (block.rope_num_kv_heads != block.num_kv_heads || + block.rope_num_pairs != block.q_head_dim / 2 || rope_dim2 != 2) { + throw std::runtime_error(block.name + ": invalid learnable rope header"); + } + block.rope_freqs = readFloats( + static_cast(block.rope_num_kv_heads) * block.rope_num_pairs * 2, + block.name + "/rope_freqs"); + } else { + readString(); // ropeTheta name + block.rope_theta = readFloat(); + } + } + return block; +} + +TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version) { + TransformerFFNBlockDesc block; + block.name = readString(); + block.num_channels = readInt(); + block.ffn_channels = readInt(); + block.use_swiglu = (readInt() != 0); + if (block.num_channels < 1 || block.ffn_channels < 1) { + throw std::runtime_error(block.name + ": transformer ffn channels must be positive"); + } + block.pre_ln = parseTransformerRMSNorm(); + block.linear1 = parseMatMulLayer(); + if (block.use_swiglu) { + block.linear_gate = parseMatMulLayer(); + } + block.linear2 = parseMatMulLayer(); + return block; +} + std::vector KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) { std::vector blocks; blocks.reserve(num_blocks); @@ -449,6 +517,14 @@ std::vector KataGoParser::parseBlockStack(int model_version, int num desc.pre_bn.num_channels, desc.post_conv.out_channels, trunk_num_channels); entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_attention_block") { + entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND; + auto desc = parseTransformerAttentionBlock(model_version); + entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_ffn_block") { + entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND; + auto desc = parseTransformerFFNBlock(model_version); + entry.block = std::make_shared(std::move(desc)); } else { throw std::runtime_error("Unknown block kind: " + block_kind_name); } @@ -506,15 +582,15 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) } // Version >= 15 writes the trunk norm kind followed by 5 unused int parameters. - // This CoreML parser only supports the standard trunk norm kind (0 = BatchNorm/BiasMask); - // RMSNorm (used by transformer/rmsnorm models, kind != 0) is not implemented here, so reject it - // defensively rather than silently parsing it as standard norm and producing wrong outputs. + // Unlike upstream's CoreML parser (which rejects any non-standard norm), this fork + // implements RMSNorm, so we capture the kind here instead of throwing. The 5 trailing + // ints are reserved and still expected to be zero. if (model_version >= 15) { - int trunk_norm_kind = readInt(); - if (trunk_norm_kind != 0) { - throw std::runtime_error(trunk.name + ": unsupported trunk norm kind " + - std::to_string(trunk_norm_kind) + - " (this CoreML parser only supports standard trunk norm, not RMSNorm)"); + trunk.trunk_norm_kind = readInt(); + if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD && + trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) { + throw std::runtime_error(trunk.name + ": unknown/unsupported trunk norm kind " + + std::to_string(trunk.trunk_norm_kind)); } for (int i = 0; i < 5; i++) { int unused = readInt(); @@ -561,14 +637,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) // Parse residual blocks trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels); - trunk.trunk_tip_bn = parseBatchNormLayer(); - trunk.trunk_tip_activation = parseActivationLayer(model_version); - if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { - throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + - std::to_string(trunk.trunk_tip_bn.num_channels) + - ") != trunkNumChannels (" + - std::to_string(trunk.trunk_num_channels) + ")"); + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk.trunk_tip_bn = parseBatchNormLayer(); + if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + + std::to_string(trunk.trunk_tip_bn.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } + } else { + trunk.trunk_tip_rms_norm = parseRMSNormLayer(); + if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" + + std::to_string(trunk.trunk_tip_rms_norm.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } } + trunk.trunk_tip_activation = parseActivationLayer(model_version); return trunk; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index a7d9f161c..201bd6b98 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -5,8 +5,11 @@ #include "../types/KataGoTypes.hpp" #include +#include #include +#include #include +#include namespace katagocoreml { @@ -31,9 +34,23 @@ class KataGoParser { private: std::string m_model_path; - std::vector m_buffer; - size_t m_pos = 0; + // Custom-deleter unique_ptr owns the gzFile so it closes on every exit path + // (normal return, exception, or bad_alloc) without manual try/catch. + struct GzCloser { + void operator()(gzFile f) const noexcept { if(f) gzclose(f); } + }; + using GzHandle = std::unique_ptr, GzCloser>; + GzHandle m_gz; + std::vector m_refill; // bounded refill buffer (~1 MB) + size_t m_refillPos = 0; // read cursor within m_refill + size_t m_refillLen = 0; // valid bytes in m_refill bool m_binary_floats = true; + bool m_formatDetected = false; + + // Stream primitives + bool refill(); // returns false at EOF + int peekByte(); // -1 at EOF + void readExact(uint8_t* dst, size_t n, const std::string& name); // Low-level reading functions void readUntilWhitespace(std::string& out); @@ -50,11 +67,15 @@ class KataGoParser { ActivationLayerDesc parseActivationLayer(int model_version); MatMulLayerDesc parseMatMulLayer(); MatBiasLayerDesc parseMatBiasLayer(); + TransformerRMSNormDesc parseTransformerRMSNorm(); + RMSNormLayerDesc parseRMSNormLayer(); // Block parsing functions ResidualBlockDesc parseResidualBlock(int model_version); GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version); NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels); + TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version); + TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version); std::vector parseBlockStack(int model_version, int num_blocks, int trunk_num_channels); // Component parsing functions @@ -65,9 +86,6 @@ class KataGoParser { // Main model parsing KataGoModelDesc parseModel(); - - // Helper to load file (handles gzip) - void loadFile(); }; } // namespace katagocoreml diff --git a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp index f271f5526..50df8003f 100644 --- a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace katagocoreml { @@ -230,8 +232,13 @@ void CoreMLSerializer::createPackage(const std::string& output_path, if (!out) { throw std::runtime_error("Failed to create temp model file"); } - if (!model->SerializeToOstream(&out)) { - throw std::runtime_error("Failed to serialize model spec"); + { + google::protobuf::io::OstreamOutputStream zos(&out); + google::protobuf::io::CodedOutputStream cos(&zos); + cos.SetSerializationDeterministic(true); + if (!model->SerializeToCodedStream(&cos)) { + throw std::runtime_error("Failed to serialize model spec"); + } } } diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..69d590609 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,20 +15,25 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { - if (use_fp16) { + const size_t count = entry.data.size(); + // Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not + // declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so + // stored bytes stay consistent with each const's declared dtype. + const bool store_fp16 = use_fp16 && !entry.is_fp32; + if (store_fp16) { // Convert FP32 weights to FP16 - std::vector fp16_data(entry.data.size()); - for (size_t i = 0; i < entry.data.size(); ++i) { + std::vector fp16_data(count); + for (size_t i = 0; i < count; ++i) { fp16_data[i] = MILBlob::Fp16::FromFloat(entry.data[i]); } MILBlob::Util::Span span(fp16_data.data(), fp16_data.size()); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(MILBlob::Fp16); + total_bytes += count * sizeof(MILBlob::Fp16); } else { - // Write FP32 weights - MILBlob::Util::Span span(entry.data.data(), entry.data.size()); + // Write FP32 weights — convert the KataGo-local view to a MILBlob span here. + MILBlob::Util::Span span(entry.data.data(), count); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(float); + total_bytes += count * sizeof(float); } } diff --git a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp index 147541a39..1074ad419 100644 --- a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp +++ b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp @@ -20,10 +20,15 @@ namespace katagocoreml { enum class ActivationType : int { Identity = 0, ReLU = 1, - Mish = 2 + Mish = 2, + Silu = 3 // MISH_SCALE8 = 12 is internal optimization, treated as Mish }; +/// Trunk normalization kind (matching KataGo's desc.h) +constexpr int TRUNK_NORM_KIND_STANDARD = 0; +constexpr int TRUNK_NORM_KIND_RMSNORM = 1; + // ============================================================================ // Block Kind Constants // ============================================================================ @@ -32,6 +37,8 @@ enum class ActivationType : int { constexpr int ORDINARY_BLOCK_KIND = 0; constexpr int GLOBAL_POOLING_BLOCK_KIND = 2; constexpr int NESTED_BOTTLENECK_BLOCK_KIND = 3; +constexpr int TRANSFORMER_ATTENTION_BLOCK_KIND = 4; +constexpr int TRANSFORMER_FFN_BLOCK_KIND = 5; // ============================================================================ // Layer Descriptors @@ -99,6 +106,25 @@ struct MatBiasLayerDesc { std::vector weights; // Shape: [num_channels] }; +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +struct TransformerRMSNormDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + std::vector weight; // Shape: [num_channels] +}; + +/// Full-featured RMSNorm (gamma/beta, spatial mode) used at the trunk tip. +struct RMSNormLayerDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + bool spatial = false; + int cgroup_size = 0; + std::vector gamma; // Shape: [num_channels] + std::vector beta; // Shape: [num_channels] +}; + // ============================================================================ // Block Descriptors // ============================================================================ @@ -107,12 +133,16 @@ struct MatBiasLayerDesc { struct ResidualBlockDesc; struct GlobalPoolingResidualBlockDesc; struct NestedBottleneckResidualBlockDesc; +struct TransformerAttentionBlockDesc; +struct TransformerFFNBlockDesc; /// Block descriptor variant using BlockDesc = std::variant< ResidualBlockDesc, GlobalPoolingResidualBlockDesc, - NestedBottleneckResidualBlockDesc + NestedBottleneckResidualBlockDesc, + TransformerAttentionBlockDesc, + TransformerFFNBlockDesc >; /// Block with its kind @@ -166,6 +196,38 @@ struct NestedBottleneckResidualBlockDesc { ConvLayerDesc post_conv; }; +/// Transformer self-attention block descriptor (pre-norm, multi-head, optional 2D RoPE, GQA). +struct TransformerAttentionBlockDesc { + std::string name; + int num_heads = 0; + int num_kv_heads = 0; + int q_head_dim = 0; + int v_head_dim = 0; + bool use_rope = false; + bool learnable_rope = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc q_proj; + MatMulLayerDesc k_proj; + MatMulLayerDesc v_proj; + MatMulLayerDesc out_proj; + int rope_num_kv_heads = 0; + int rope_num_pairs = 0; + std::vector rope_freqs; // learnable: (num_kv_heads, num_pairs, 2) flattened + float rope_theta = 0.0f; +}; + +/// Transformer feed-forward (SwiGLU) block descriptor. +struct TransformerFFNBlockDesc { + std::string name; + int num_channels = 0; + int ffn_channels = 0; + bool use_swiglu = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc linear1; + MatMulLayerDesc linear_gate; // only used when use_swiglu + MatMulLayerDesc linear2; +}; + // ============================================================================ // SGF Metadata Encoder (v15+) // ============================================================================ @@ -203,7 +265,9 @@ struct TrunkDesc { MatMulLayerDesc initial_matmul; std::optional sgf_metadata_encoder; std::vector blocks; + int trunk_norm_kind = TRUNK_NORM_KIND_STANDARD; BatchNormLayerDesc trunk_tip_bn; + RMSNormLayerDesc trunk_tip_rms_norm; ActivationLayerDesc trunk_tip_activation; }; diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 8141b4366..72e01238d 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -197,6 +197,10 @@ void ConvLayerDesc::scaleOutputChannels(const std::vector& scaling) { } } +void ConvLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- BatchNormLayerDesc::BatchNormLayerDesc() : numChannels(0), epsilon(0.001f), hasScale(false), hasBias(false) {} @@ -389,6 +393,15 @@ ActivationLayerDesc::ActivationLayerDesc(istream& in, int modelVersion) { } } +void BatchNormLayerDesc::releaseWeights() { + std::vector().swap(mean); + std::vector().swap(variance); + std::vector().swap(scale); + std::vector().swap(bias); + std::vector().swap(mergedScale); + std::vector().swap(mergedBias); +} + ActivationLayerDesc::ActivationLayerDesc(ActivationLayerDesc&& other) { *this = std::move(other); } @@ -517,6 +530,10 @@ MatBiasLayerDesc::MatBiasLayerDesc(istream& in, bool binaryFloats) { throw StringError(name + ": matbiaslayer failed to parse expected number of matbias weights"); } +void MatMulLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + MatBiasLayerDesc::MatBiasLayerDesc(MatBiasLayerDesc&& other) { *this = std::move(other); } @@ -538,6 +555,10 @@ void MatBiasLayerDesc::applyScale8ToReduceActivations() { } } +void MatBiasLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- ResidualBlockDesc::ResidualBlockDesc() {} @@ -617,6 +638,13 @@ void ResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void ResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- GlobalPoolingResidualBlockDesc::GlobalPoolingResidualBlockDesc() {} @@ -738,6 +766,16 @@ void GlobalPoolingResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void GlobalPoolingResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + gpoolConv.releaseWeights(); + gpoolBN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- NestedBottleneckResidualBlockDesc::NestedBottleneckResidualBlockDesc() {} @@ -992,6 +1030,38 @@ void NestedBottleneckResidualBlockDesc::applyScale8ToReduceActivations() { postActivation.applyScale8ToReduceActivations(); } +void NestedBottleneckResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + preConv.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_ATTENTION_BLOCK_KIND) { + TransformerAttentionDesc* desc = (TransformerAttentionDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_FFN_BLOCK_KIND) { + TransformerFFNDesc* desc = (TransformerFFNDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + postBN.releaseWeights(); + postConv.releaseWeights(); +} + //----------------------------------------------------------------------------- RMSNormLayerDesc::RMSNormLayerDesc() : numChannels(0), epsilon(0), spatial(false), cgroupSize(0) {} @@ -1043,6 +1113,11 @@ int64_t RMSNormLayerDesc::getNumParameters() const { return (int64_t)gamma.size() + (int64_t)beta.size(); } +void RMSNormLayerDesc::releaseWeights() { + std::vector().swap(gamma); + std::vector().swap(beta); +} + //----------------------------------------------------------------------------- TransformerRMSNormDesc::TransformerRMSNormDesc() : numChannels(0), epsilon(0) {} @@ -1083,6 +1158,10 @@ int64_t TransformerRMSNormDesc::getNumParameters() const { return (int64_t)weight.size(); } +void TransformerRMSNormDesc::releaseWeights() { + std::vector().swap(weight); +} + //----------------------------------------------------------------------------- TransformerAttentionDesc::TransformerAttentionDesc() @@ -1209,6 +1288,15 @@ int64_t TransformerAttentionDesc::getNumParameters() const { (int64_t)ropeFreqs.size(); // learnable RoPE frequencies, empty for fixed/no RoPE } +void TransformerAttentionDesc::releaseWeights() { + preLN.releaseWeights(); + qProj.releaseWeights(); + kProj.releaseWeights(); + vProj.releaseWeights(); + outProj.releaseWeights(); + std::vector().swap(ropeFreqs); +} + void TransformerAttentionDesc::computeRopeCosSin(int nnXLen, int nnYLen, int paddedNNXYLen, std::vector& cosTable, std::vector& sinTable) const { if(!useRope) throw StringError("TransformerAttentionDesc::computeRopeCosSin called but useRope is false"); @@ -1344,6 +1432,13 @@ int64_t TransformerFFNDesc::getNumParameters() const { linear2.getNumParameters(); } +void TransformerFFNDesc::releaseWeights() { + preLN.releaseWeights(); + linear1.releaseWeights(); + linearGate.releaseWeights(); + linear2.releaseWeights(); +} + //----------------------------------------------------------------------------- static void parseResidualBlockStack( @@ -1550,6 +1645,14 @@ int64_t SGFMetadataEncoderDesc::getNumParameters() const { mul3.getNumParameters(); } +void SGFMetadataEncoderDesc::releaseWeights() { + mul1.releaseWeights(); + bias1.releaseWeights(); + mul2.releaseWeights(); + bias2.releaseWeights(); + mul3.releaseWeights(); +} + //----------------------------------------------------------------------------- TrunkDesc::TrunkDesc() @@ -1906,6 +2009,40 @@ void TrunkDesc::applyScale8ToReduceActivations() { } } +void TrunkDesc::releaseWeights() { + initialConv.releaseWeights(); + initialMatMul.releaseWeights(); + if(metaEncoderVersion > 0) + sgfMetadataEncoder.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_ATTENTION_BLOCK_KIND) { + TransformerAttentionDesc* desc = (TransformerAttentionDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_FFN_BLOCK_KIND) { + TransformerFFNDesc* desc = (TransformerFFNDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + // Whichever trunk tip norm is unused has empty parameter vectors, so releasing both is safe. + trunkTipBN.releaseWeights(); + trunkTipRMSNorm.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -2086,6 +2223,18 @@ void PolicyHeadDesc::applyScale8ToReduceActivations() { passActivation.applyScale8ToReduceActivations(); } +void PolicyHeadDesc::releaseWeights() { + p1Conv.releaseWeights(); + g1Conv.releaseWeights(); + g1BN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + p1BN.releaseWeights(); + p2Conv.releaseWeights(); + gpoolToPassMul.releaseWeights(); + gpoolToPassBias.releaseWeights(); + gpoolToPassMul2.releaseWeights(); +} + //----------------------------------------------------------------------------- ValueHeadDesc::ValueHeadDesc() : modelVersion(-1) {} @@ -2246,6 +2395,17 @@ void ValueHeadDesc::applyScale8ToReduceActivations() { sv3Bias.applyScale8ToReduceActivations(); } +void ValueHeadDesc::releaseWeights() { + v1Conv.releaseWeights(); + v1BN.releaseWeights(); + v2Mul.releaseWeights(); + v2Bias.releaseWeights(); + v3Mul.releaseWeights(); + v3Bias.releaseWeights(); + sv3Mul.releaseWeights(); + sv3Bias.releaseWeights(); + vOwnershipConv.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -2562,6 +2722,12 @@ void ModelDesc::applyScale8ToReduceActivations() { postProcessParams.outputScaleMultiplier *= 8.0f; } +void ModelDesc::releaseWeights() { + trunk.releaseWeights(); + policyHead.releaseWeights(); + valueHead.releaseWeights(); +} + struct NonCopyingStreamBuf : public std::streambuf { NonCopyingStreamBuf(string& str) { diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index 36c5a11d8..ef41dfca6 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -36,6 +36,8 @@ struct ConvLayerDesc { int64_t getNumParameters() const; void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct BatchNormLayerDesc { @@ -68,6 +70,8 @@ struct BatchNormLayerDesc { void extractChannelFactorsAbsLtOne(std::vector& channelFactors); void extractChannelFactorsAbsLtOneWithInverses(std::vector& channelFactors, std::vector& invChannelFactors); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ActivationLayerDesc { @@ -105,6 +109,8 @@ struct MatMulLayerDesc { int64_t getNumParameters() const; void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct MatBiasLayerDesc { @@ -124,6 +130,8 @@ struct MatBiasLayerDesc { int64_t getNumParameters() const; void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ResidualBlockDesc { @@ -150,6 +158,8 @@ struct ResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct GlobalPoolingResidualBlockDesc { @@ -181,6 +191,8 @@ struct GlobalPoolingResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct NestedBottleneckResidualBlockDesc { @@ -215,6 +227,8 @@ struct NestedBottleneckResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; // Trunk final normalization kind (stored in trunk header) @@ -240,6 +254,7 @@ struct RMSNormLayerDesc { RMSNormLayerDesc& operator=(RMSNormLayerDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; // Lightweight RMSNorm used inside transformer blocks (weight only, no bias, no spatial modes) @@ -259,6 +274,7 @@ struct TransformerRMSNormDesc { TransformerRMSNormDesc& operator=(TransformerRMSNormDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; struct TransformerAttentionDesc { @@ -294,6 +310,7 @@ struct TransformerAttentionDesc { TransformerAttentionDesc& operator=(TransformerAttentionDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); // Compute cos/sin tables for RoPE given board dimensions. // Output tables are indexed as: @@ -324,6 +341,7 @@ struct TransformerFFNDesc { TransformerFFNDesc& operator=(TransformerFFNDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; struct SGFMetadataEncoderDesc { @@ -349,6 +367,7 @@ struct SGFMetadataEncoderDesc { SGFMetadataEncoderDesc& operator=(SGFMetadataEncoderDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; @@ -397,6 +416,8 @@ struct TrunkDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct PolicyHeadDesc { @@ -431,6 +452,8 @@ struct PolicyHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ValueHeadDesc { @@ -463,6 +486,8 @@ struct ValueHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ModelPostProcessParams { @@ -534,6 +559,11 @@ struct ModelDesc { //Fills supported with true if desiredRules itself was exactly supported, false if some modifications had to be made. Rules getSupportedRules(const Rules& desiredRules, bool& supported) const; + // Frees all weight arrays (conv/matmul/bias/batchnorm), keeping scalar shape + // metadata intact. Safe once weights are no longer needed (e.g. CoreML/ANE + // inference, which reads weights from the compiled .mlmodelc). + void releaseWeights(); + }; #endif // #ifndef DESC_H diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 786ef8290..172e337d4 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -130,6 +130,8 @@ ActivationKind activationLayerDescToSwift(const ActivationLayerDesc* desc) { return ActivationKind::mish(); case ACTIVATION_MISH_SCALE8: return ActivationKind::identity(); // Metal/CoreML does not use scaled mish + case ACTIVATION_SILU: + return ActivationKind::silu(); case ACTIVATION_IDENTITY: return ActivationKind::identity(); default: @@ -217,6 +219,63 @@ SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(con postConv); } +/// Convert a transformer RMSNorm description from C++ to Swift +SWTransformerRMSNormDesc transformerRMSNormDescToSwift(const TransformerRMSNormDesc* desc) { + return createSWTransformerRMSNormDesc( + desc->numChannels, + desc->epsilon, + (float*)desc->weight.data()); +} + +/// Convert a transformer attention block description from C++ to Swift +SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const TransformerAttentionDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc qProj = matMulLayerDescToSwift(&desc->qProj); + SWMatMulLayerDesc kProj = matMulLayerDescToSwift(&desc->kProj); + SWMatMulLayerDesc vProj = matMulLayerDescToSwift(&desc->vProj); + SWMatMulLayerDesc outProj = matMulLayerDescToSwift(&desc->outProj); + float* ropeFreqs = desc->ropeFreqs.empty() ? nullptr : (float*)desc->ropeFreqs.data(); + + return createSWTransformerAttentionBlockDesc( + desc->numHeads, + desc->numKVHeads, + desc->qHeadDim, + desc->vHeadDim, + desc->useRope, + desc->learnableRope, + preLN, + qProj, + kProj, + vProj, + outProj, + desc->ropeNumKVHeads, + desc->ropeNumPairs, + ropeFreqs, + desc->ropeTheta); +} + +/// Convert a transformer FFN block description from C++ to Swift +SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + // The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path + // (SiLU(linear1) * gate); a non-SwiGLU model has no gate weights, so guard here as Eigen and CoreML + // do (eigenbackend.cpp / katagocoreml MILBuilder) instead of crashing on the empty gate descriptor. + if(!desc->useSwiGLU) + throw StringError(desc->name + ": non-SwiGLU transformer FFN not supported in Metal backend"); + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); + SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); + SWMatMulLayerDesc linear2 = matMulLayerDescToSwift(&desc->linear2); + + return createSWTransformerFFNBlockDesc( + desc->numChannels, + desc->ffnChannels, + desc->useSwiGLU, + preLN, + linear1, + linearGate, + linear2); +} + /// Convert residual blocks from C++ to Swift swift::Array residualBlocksToSwift(const vector>& blocks) { auto builder = createBlockDescriptorBuilder(); @@ -230,9 +289,12 @@ swift::Array residualBlocksToSwift(const vector sGFMetadataEncoderDescToSwift(const SG } /// Convert a trunk description from C++ to Swift +SWRMSNormLayerDesc rmsNormLayerDescToSwift(const RMSNormLayerDesc* desc) { + float* gamma = desc->gamma.empty() ? nullptr : (float*)desc->gamma.data(); + float* beta = desc->beta.empty() ? nullptr : (float*)desc->beta.data(); + return createSWRMSNormLayerDesc( + desc->numChannels, + desc->epsilon, + desc->spatial, + gamma, + beta); +} + SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { SWConvLayerDesc initialConv = convLayerDescToSwift(&trunk->initialConv); SWMatMulLayerDesc initialMatMul = matMulLayerDescToSwift(&trunk->initialMatMul); auto sgfMetadataEncoder = sGFMetadataEncoderDescToSwift(&trunk->sgfMetadataEncoder); auto swBlocks = residualBlocksToSwift(trunk->blocks); - if(trunk->trunkNormKind != TRUNK_NORM_KIND_STANDARD) - throw StringError("Trunk RMSNorm is not yet supported by the Metal backend"); SWBatchNormLayerDesc trunkTipBN = batchNormLayerDescToSwift(&trunk->trunkTipBN); + SWRMSNormLayerDesc trunkTipRMSNorm = rmsNormLayerDescToSwift(&trunk->trunkTipRMSNorm); ActivationKind trunkTipActivation = activationLayerDescToSwift(&trunk->trunkTipActivation); return createSWTrunkDesc( @@ -285,7 +357,9 @@ SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { initialMatMul, sgfMetadataEncoder, swBlocks, + trunk->trunkNormKind, trunkTipBN, + trunkTipRMSNorm, trunkTipActivation); } @@ -426,13 +500,27 @@ ComputeContext* NeuralNet::createComputeContext( const LoadedModel* loadedModel, ConfigParser& cfg) { - (void)gpuIdxs; + // Only ANE-only configurations may free the engine's in-memory weights: the + // GPU/MPSGraph path reads them via modelDescToSwift, so freeing is unsafe + // unless no GPU handle can ever be built from this model. + // INVARIANT: gpuIdxs must be the complete (deduplicated) set of device indices + // that will ever be passed as gpuIdxForThisThread to createComputeHandle for + // this context. aneOnly==true frees the in-memory weights, so if any thread + // later used a GPU (MPSGraph) index not represented here, it would read freed + // weights. KataGo derives both from the same gpuIdxByServerThread list, so the + // invariant holds today; preserve it if that wiring ever changes. + bool aneOnly = !gpuIdxs.empty(); + for(int idx : gpuIdxs) { + if(idx != METAL_MUX_ANE) { aneOnly = false; break; } + } (void)logger; (void)homeDataDirOverride; (void)loadedModel; (void)cfg; - return new ComputeContext(nnXLen, nnYLen, useFP16Mode); + ComputeContext* context = new ComputeContext(nnXLen, nnYLen, useFP16Mode); + context->aneOnly = aneOnly; + return context; } void NeuralNet::freeComputeContext(ComputeContext* computeContext) { @@ -459,6 +547,17 @@ static swift::Optional convertAndCreateCoreMLO bool useFP16 = (context->useFP16Mode != enabled_t::False); bool optimizeMask = requireExactNNLen; + // On a confirmed ANE-only run, free the engine's in-memory ModelDesc weight + // arrays. This function converts from loadedModel->modelPath (disk), + // so the in-memory weights are not read here; the GPU/MPSGraph path (which + // DOES read them via modelDescToSwift) is never built when aneOnly is true. + // The whole ComputeHandle ctor runs under computeHandleMutex, so this is not + // racy; releaseWeights() clears only weight vectors, leaving the scalar dims + // read by the ComputeHandle ctor / InputBuffers valid. + if(context->aneOnly) { + const_cast(loadedModel)->modelDesc.releaseWeights(); + } + // Convert model to CoreML format in temp directory string coremlModelPath = CoreMLConversion::convertModelToTemp( loadedModel->modelPath, diff --git a/cpp/neuralnet/metalbackend.h b/cpp/neuralnet/metalbackend.h index b7f751e63..e77dd18d9 100644 --- a/cpp/neuralnet/metalbackend.h +++ b/cpp/neuralnet/metalbackend.h @@ -113,6 +113,14 @@ struct ComputeContext { */ MetalComputeContext metalContext; + /** + * @brief True only when every configured device is METAL_MUX_ANE, so no + * MPSGraph (GPU) handle will ever read modelDesc weights. Gates the call to + * ModelDesc::releaseWeights() so a mixed GPU+ANE config can never free live + * weights. + */ + bool aneOnly = false; + /** * @brief Constructs a ComputeContext object. * @param nnX The width of the input tensor. @@ -179,6 +187,18 @@ struct ComputeHandle { */ bool maskIdentityChecked = false; + // Weight-release safety is guaranteed by ComputeContext::aneOnly, NOT by the + // declaration order below: within a single ComputeHandle exactly one handle is + // built (the two paths are mutually exclusive on gpuIdx, enforced by the + // ctor's exactly-one check), and releaseWeights() only ever fires on an + // aneOnly context, where no MPSGraph handle is built for any thread. + // That said, keep mpsGraphOnlyHandle declared before coremlOnlyHandle. C++ + // initializes members in DECLARATION order, so createMPSGraphHandleIfNeeded + // (which reads modelDesc weights via modelDescToSwift) is sequenced before + // createCoreMLOnlyHandleIfNeeded (which may call modelDesc.releaseWeights()). + // This ordering is belt-and-suspenders that preserves the natural read-then- + // release sequence should the aneOnly invariant ever be weakened; don't rely + // on it as the primary guarantee, but don't reorder it either. /** * @brief The MPSGraph-only handle instance from Swift (GPU-only mode). */ diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index bbd2255bc..e1324df96 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -76,6 +76,11 @@ extension MPSGraph { return mulTensor } + + /// SiLU / Swish activation: x * sigmoid(x). Numerically stable across FP16/FP32. + func silu(tensor: MPSGraphTensor) -> MPSGraphTensor { + return multiplication(tensor, sigmoid(with: tensor, name: nil), name: nil) + } } // MARK: - Input Shape Utilities @@ -358,6 +363,7 @@ public enum ActivationKind { case identity case relu case mish + case silu } /// A struct that represents a description of convolutional layer. @@ -487,6 +493,63 @@ public func createSWMatBiasLayerDesc( weights: weights) } +/// A lightweight RMSNorm description used inside transformer blocks (weight only, no bias). +public struct SWTransformerRMSNormDesc { + let numChannels: NSNumber + let epsilon: Float + let weight: UnsafeMutablePointer + + init(numChannels: NSNumber, epsilon: Float, weight: UnsafeMutablePointer) { + self.numChannels = numChannels + self.epsilon = epsilon + self.weight = weight + } +} + +public func createSWTransformerRMSNormDesc( + numChannels: Int32, + epsilon: Float, + weight: UnsafeMutablePointer +) -> SWTransformerRMSNormDesc { + return SWTransformerRMSNormDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + weight: weight) +} + +/// A full-featured RMSNorm description (gamma/beta, spatial mode), used at the trunk tip. +public struct SWRMSNormLayerDesc { + let numChannels: NSNumber + let epsilon: Float + let spatial: Bool + let gamma: UnsafeMutablePointer? + let beta: UnsafeMutablePointer? + + init(numChannels: NSNumber, epsilon: Float, spatial: Bool, + gamma: UnsafeMutablePointer?, beta: UnsafeMutablePointer?) { + self.numChannels = numChannels + self.epsilon = epsilon + self.spatial = spatial + self.gamma = gamma + self.beta = beta + } +} + +public func createSWRMSNormLayerDesc( + numChannels: Int32, + epsilon: Float, + spatial: Bool, + gamma: UnsafeMutablePointer?, + beta: UnsafeMutablePointer? +) -> SWRMSNormLayerDesc { + return SWRMSNormLayerDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + spatial: spatial, + gamma: gamma, + beta: beta) +} + // MARK: - Core Layers /// A class that represents a convolutional layer using MPSGraph @@ -612,6 +675,8 @@ struct ActivationLayer { resultTensor = graph.reLU(with: sourceTensor, name: nil) case .mish: resultTensor = graph.mish(tensor: sourceTensor) + case .silu: + resultTensor = graph.silu(tensor: sourceTensor) default: resultTensor = sourceTensor } @@ -987,6 +1052,140 @@ public func createSWNestedBottleneckResidualBlockDesc( postConv: postConv) } +public class SWTransformerAttentionBlockDesc: BlockDescriptor { + let numHeads: Int + let numKVHeads: Int + let qHeadDim: Int + let vHeadDim: Int + let useRope: Bool + let learnableRope: Bool + let preLN: SWTransformerRMSNormDesc + let qProj: SWMatMulLayerDesc + let kProj: SWMatMulLayerDesc + let vProj: SWMatMulLayerDesc + let outProj: SWMatMulLayerDesc + let ropeNumKVHeads: Int + let ropeNumPairs: Int + let ropeFreqs: UnsafeMutablePointer? // learnable: (numKVHeads, numPairs, 2) flattened + let ropeTheta: Float + + init( + numHeads: Int, + numKVHeads: Int, + qHeadDim: Int, + vHeadDim: Int, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int, + ropeNumPairs: Int, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float + ) { + self.numHeads = numHeads + self.numKVHeads = numKVHeads + self.qHeadDim = qHeadDim + self.vHeadDim = vHeadDim + self.useRope = useRope + self.learnableRope = learnableRope + self.preLN = preLN + self.qProj = qProj + self.kProj = kProj + self.vProj = vProj + self.outProj = outProj + self.ropeNumKVHeads = ropeNumKVHeads + self.ropeNumPairs = ropeNumPairs + self.ropeFreqs = ropeFreqs + self.ropeTheta = ropeTheta + } +} + +public func createSWTransformerAttentionBlockDesc( + numHeads: Int32, + numKVHeads: Int32, + qHeadDim: Int32, + vHeadDim: Int32, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int32, + ropeNumPairs: Int32, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float +) -> SWTransformerAttentionBlockDesc { + return SWTransformerAttentionBlockDesc( + numHeads: Int(numHeads), + numKVHeads: Int(numKVHeads), + qHeadDim: Int(qHeadDim), + vHeadDim: Int(vHeadDim), + useRope: useRope, + learnableRope: learnableRope, + preLN: preLN, + qProj: qProj, + kProj: kProj, + vProj: vProj, + outProj: outProj, + ropeNumKVHeads: Int(ropeNumKVHeads), + ropeNumPairs: Int(ropeNumPairs), + ropeFreqs: ropeFreqs, + ropeTheta: ropeTheta) +} + +public class SWTransformerFFNBlockDesc: BlockDescriptor { + let numChannels: Int + let ffnChannels: Int + let useSwiGLU: Bool + let preLN: SWTransformerRMSNormDesc + let linear1: SWMatMulLayerDesc + let linearGate: SWMatMulLayerDesc + let linear2: SWMatMulLayerDesc + + init( + numChannels: Int, + ffnChannels: Int, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc + ) { + self.numChannels = numChannels + self.ffnChannels = ffnChannels + self.useSwiGLU = useSwiGLU + self.preLN = preLN + self.linear1 = linear1 + self.linearGate = linearGate + self.linear2 = linear2 + } +} + +public func createSWTransformerFFNBlockDesc( + numChannels: Int32, + ffnChannels: Int32, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc +) -> SWTransformerFFNBlockDesc { + return SWTransformerFFNBlockDesc( + numChannels: Int(numChannels), + ffnChannels: Int(ffnChannels), + useSwiGLU: useSwiGLU, + preLN: preLN, + linear1: linear1, + linearGate: linearGate, + linear2: linear2) +} + public class BlockDescriptorBuilder { public var blockDescriptors: [BlockDescriptor] = [] @@ -1001,6 +1200,329 @@ public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { return BlockDescriptorBuilder() } +// MARK: - Transformer Layers + +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +/// Input/output are NCHW [B, C, H, W]. Normalizes across channels per spatial position, +/// scales by per-channel weight, and masks the output. +struct TransformerRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerRMSNormDesc + ) { + let numChannels = descriptor.numChannels + let dataType = sourceTensor.dataType + + // meanSq over channel axis (1): [B,1,H,W] + let sq = graph.square(with: sourceTensor, name: nil) + let sumSq = graph.reductionSum(with: sq, axis: 1, name: nil) + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + let meanSq = graph.multiplication(sumSq, invC, name: nil) + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + // scale by per-channel weight [1, C, 1, 1] + let weightShape: [NSNumber] = [1, numChannels, 1, 1] + let weightData = Data(floatsNoCopy: descriptor.weight, shape: weightShape) + let weightTensor = graph.constant(weightData, shape: weightShape, dataType: dataType) + let scaled = graph.multiplication(normalized, weightTensor, name: nil) + + resultTensor = graph.multiplication(scaled, maskTensor, name: nil) + } +} + +/// Full-featured RMSNorm for the trunk tip: gamma/beta, spatial or per-position mode, and a +/// fused activation. Input/output are NCHW [B, C, H, W]. Mirrors the Eigen RMSNormLayer. +struct TrunkRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWRMSNormLayerDesc, + activationKind: ActivationKind + ) { + let dataType = sourceTensor.dataType + let numChannels = descriptor.numChannels + + // Zero invalid positions before accumulating sum of squares. + let masked = graph.multiplication(sourceTensor, maskTensor, name: nil) + let sq = graph.square(with: masked, name: nil) + + let meanSq: MPSGraphTensor + if descriptor.spatial { + // Normalize over channels AND valid spatial positions per batch element. + let sumSq = graph.reductionSum(with: sq, axes: [1, 2, 3], name: nil) // [B,1,1,1] + let count = graph.reductionSum(with: maskTensor, axes: [1, 2, 3], name: nil) // valid positions + let cTensor = graph.constant(numChannels.doubleValue, dataType: dataType) + let totalElts = graph.multiplication(count, cTensor, name: nil) + meanSq = graph.division(sumSq, totalElts, name: nil) + } else { + // Per-position normalization across channels. + let sumSq = graph.reductionSum(with: sq, axes: [1], name: nil) // [B,1,H,W] + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + meanSq = graph.multiplication(sumSq, invC, name: nil) + } + + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + let gammaShape: [NSNumber] = [1, numChannels, 1, 1] + let gammaTensor = graph.constant(Data(floatsNoCopy: descriptor.gamma!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let betaTensor = graph.constant(Data(floatsNoCopy: descriptor.beta!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let scaled = graph.addition(graph.multiplication(normalized, gammaTensor, name: nil), betaTensor, name: nil) + + let activated = ActivationLayer(graph: graph, sourceTensor: scaled, activationKind: activationKind).resultTensor + resultTensor = graph.multiplication(activated, maskTensor, name: nil) + } +} + +/// A transformer self-attention block (pre-norm, multi-head, optional 2D RoPE, GQA). +/// Mirrors the Eigen reference: RMSNorm -> Q/K/V projections -> RoPE -> scaled dot-product +/// attention with masked softmax -> output projection -> masked residual. +/// Tensors are NCHW [B, C, H, W]; spatial positions (H*W, ordered y*W+x) are the sequence. +struct TransformerAttentionBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerAttentionBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let dataType = sourceTensor.dataType + let numHeads = descriptor.numHeads + let numKVHeads = descriptor.numKVHeads + let qHeadDim = descriptor.qHeadDim + let vHeadDim = descriptor.vHeadDim + let nnX = nnXLen.intValue + let nnY = nnYLen.intValue + let seq = nnX * nnY + + // 1. RMSNorm (NCHW) + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + + // To NHWC [B,H,W,C] so that reshape [-1, C] groups channels per position. + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. Q/K/V projections via matmul over channels -> [B*seq, heads*dim] + let q = MatMulLayer(graph: graph, descriptor: descriptor.qProj, sourceTensor: normedNHWC).resultTensor + let k = MatMulLayer(graph: graph, descriptor: descriptor.kProj, sourceTensor: normedNHWC).resultTensor + let v = MatMulLayer(graph: graph, descriptor: descriptor.vProj, sourceTensor: normedNHWC).resultTensor + + // 3. reshape to [B, heads, seq, dim] + var qh = TransformerAttentionBlock.toHeads(graph, q, seq: seq, numHeads: numHeads, headDim: qHeadDim) + var kh = TransformerAttentionBlock.toHeads(graph, k, seq: seq, numHeads: numKVHeads, headDim: qHeadDim) + let vh = TransformerAttentionBlock.toHeads(graph, v, seq: seq, numHeads: numKVHeads, headDim: vHeadDim) + + // 4. RoPE on Q and K + if descriptor.useRope { + let numPairs = qHeadDim / 2 + // Q heads map to KV heads via kvh = h * numKVHeads / numHeads (matches Eigen). + let (qCos, qSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in (h * numKVHeads) / numHeads }) + let (kCos, kSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numKVHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in h }) + qh = TransformerAttentionBlock.applyRope(graph, qh, cosT: qCos, sinT: qSin, + numHeads: numHeads, seq: seq, numPairs: numPairs) + kh = TransformerAttentionBlock.applyRope(graph, kh, cosT: kCos, sinT: kSin, + numHeads: numKVHeads, seq: seq, numPairs: numPairs) + } + + // GQA: if numKVHeads < numHeads, repeat KV heads so they align with query heads. + var khExp = kh + var vhExp = vh + if numKVHeads != numHeads { + let groupSize = numHeads / numKVHeads + khExp = TransformerAttentionBlock.repeatKVHeads(graph, kh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: qHeadDim) + vhExp = TransformerAttentionBlock.repeatKVHeads(graph, vh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: vHeadDim) + } + + // 5. scores = scale * Q @ K^T -> [B, heads, seq, seq] + let khT = graph.transpose(khExp, permutation: [0, 1, 3, 2], name: nil) + var scores = graph.matrixMultiplication(primary: qh, secondary: khT, name: nil) + let scale = graph.constant(1.0 / Double(qHeadDim).squareRoot(), dataType: dataType) + scores = graph.multiplication(scores, scale, name: nil) + + // Mask keys: add (maskKey - 1) * BIG so masked key columns get ~ -inf before softmax. + // maskTensor [B,1,H,W] -> [B,1,1,seq] + let maskNHWC = graph.transpose(maskTensor, permutation: [0, 2, 3, 1], name: nil) // [B,H,W,1] + let maskSeq = graph.reshape(maskNHWC, shape: [-1, 1, 1, seq as NSNumber], name: nil) + let one = graph.constant(1.0, dataType: dataType) + let big = graph.constant(1.0e9, dataType: dataType) + let keyBias = graph.multiplication(graph.subtraction(maskSeq, one, name: nil), big, name: nil) + scores = graph.addition(scores, keyBias, name: nil) + + // 6. softmax over key axis (last) + let attn = graph.softMax(with: scores, axis: 3, name: nil) + + // 7. out = attn @ V -> [B, heads, seq, vHeadDim] + let attnOut = graph.matrixMultiplication(primary: attn, secondary: vhExp, name: nil) + + // 8. back to [B*seq, heads*vHeadDim] + let outHeadsLast = graph.transpose(attnOut, permutation: [0, 2, 1, 3], name: nil) // [B,seq,heads,vHeadDim] + let outFlat = graph.reshape(outHeadsLast, shape: [-1, (numHeads * vHeadDim) as NSNumber], name: nil) + + // 9. output projection -> [B*seq, C] + let proj = MatMulLayer(graph: graph, descriptor: descriptor.outProj, sourceTensor: outFlat).resultTensor + + // 10. reshape to NHWC then NCHW + let outChannels = descriptor.outProj.outChannels + let projNHWC = graph.reshape(proj, shape: [-1, nnYLen, nnXLen, outChannels], name: nil) + let projNCHW = graph.transpose(projNHWC, permutation: [0, 3, 1, 2], name: nil) + + // 11. masked residual + let masked = graph.multiplication(projNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } + + /// Reshape [B*seq, numHeads*headDim] -> [B, numHeads, seq, headDim]. + static func toHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, seq: Int, numHeads: Int, headDim: Int) -> MPSGraphTensor { + let reshaped = graph.reshape(x, shape: [-1, seq as NSNumber, numHeads as NSNumber, headDim as NSNumber], name: nil) + return graph.transpose(reshaped, permutation: [0, 2, 1, 3], name: nil) + } + + /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. + static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { + // Repeat each KV head groupSize times consecutively so query head h uses kv = h / groupSize, + // matching the Eigen reference (kvh = h / kvGroupSize). We slice each KV head and concat the + // copies along the head axis. Note: MPSGraph.broadcast(_:shape:) does NOT infer -1, so a + // reshape+broadcast approach with a dynamic batch dim triggers an NDArray INT_MAX assertion; + // slice+concat is shape-safe with no -1 broadcast. + var heads: [MPSGraphTensor] = [] + heads.reserveCapacity(numKVHeads * groupSize) + for kv in 0.. MPSGraphTensor { + let pairsShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 2] + let xPairs = graph.reshape(x, shape: pairsShape, name: nil) + let evenShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber] + let xEven = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 0, length: 1, name: nil), shape: evenShape, name: nil) + let xOdd = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 1, length: 1, name: nil), shape: evenShape, name: nil) + let outEven = graph.subtraction(graph.multiplication(xEven, cosT, name: nil), graph.multiplication(xOdd, sinT, name: nil), name: nil) + let outOdd = graph.addition(graph.multiplication(xEven, sinT, name: nil), graph.multiplication(xOdd, cosT, name: nil), name: nil) + let pairShape5: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 1] + let outEvenE = graph.reshape(outEven, shape: pairShape5, name: nil) + let outOddE = graph.reshape(outOdd, shape: pairShape5, name: nil) + let stacked = graph.concatTensors([outEvenE, outOddE], dimension: 4, name: nil) + return graph.reshape(stacked, shape: [-1, numHeads as NSNumber, seq as NSNumber, (numPairs * 2) as NSNumber], name: nil) + } + + /// Build RoPE cos/sin constant tensors of shape [1, nHeads, seq, numPairs]. + static func makeRopeTables(_ graph: MPSGraph, descriptor: SWTransformerAttentionBlockDesc, + nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, + dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { + let count = nHeads * seq * numPairs + // Managed arrays (freed on return). Unlike the weight constants elsewhere, which point at + // C++-owned descriptor memory and so use floatsNoCopy, these tables have no persistent owner; + // we copy them into the Data below so MPSGraph owns the bytes (avoids a leak / use-after-free). + var cosBuf = [Float32](repeating: 0, count: count) + var sinBuf = [Float32](repeating: 0, count: count) + let numPairsPerDim = numPairs / 2 + let dimHalf = qHeadDim / 2 + for h in 0.. SiLU(linear1)*gate -> linear2 -> masked residual. +struct TransformerFFNBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerFFNBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let numChannels = descriptor.numChannels + + // 1. RMSNorm + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. linear1 + gate, both [B*seq, ffnChannels] + let a = MatMulLayer(graph: graph, descriptor: descriptor.linear1, sourceTensor: normedNHWC).resultTensor + let gate = MatMulLayer(graph: graph, descriptor: descriptor.linearGate, sourceTensor: normedNHWC).resultTensor + + // 3. SwiGLU: SiLU(a) * gate, SiLU(a) = a * sigmoid(a) + let siluA = graph.multiplication(a, graph.sigmoid(with: a, name: nil), name: nil) + let h = graph.multiplication(siluA, gate, name: nil) + + // 4. linear2 -> [B*seq, numChannels] + let out = MatMulLayer(graph: graph, descriptor: descriptor.linear2, sourceTensor: h).resultTensor + + // 5. reshape to NHWC then NCHW, masked residual + let outNHWC = graph.reshape(out, shape: [-1, nnYLen, nnXLen, numChannels as NSNumber], name: nil) + let outNCHW = graph.transpose(outNHWC, permutation: [0, 3, 1, 2], name: nil) + let masked = graph.multiplication(outNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } +} + // MARK: - Block Implementations /// A class that represents a Residual Block layer @@ -1241,6 +1763,26 @@ struct BlockStack { optimizeIdentityMask: optimizeIdentityMask) blockInput = ordinary.resultTensor + case let attnDescriptor as SWTransformerAttentionBlockDesc: + let attn = TransformerAttentionBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: attnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = attn.resultTensor + case let ffnDescriptor as SWTransformerFFNBlockDesc: + let ffn = TransformerFFNBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: ffnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = ffn.resultTensor default: blockInput = sourceTensor } @@ -1472,6 +2014,10 @@ class SGFMetadataEncoder { // MARK: - Trunk +/// Trunk-tip normalization kind, mirroring desc.h TRUNK_NORM_KIND_* (the value is serialized in the model). +let TRUNK_NORM_KIND_STANDARD = 0 // BatchNorm or BiasMask (existing) +let TRUNK_NORM_KIND_RMSNORM = 1 // RMSNorm + /// A class that describes a trunk for a neural network public class SWTrunkDesc { let version: Int @@ -1483,7 +2029,9 @@ public class SWTrunkDesc { let initialMatMul: SWMatMulLayerDesc let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? let blockDescriptors: [BlockDescriptor] + let trunkNormKind: Int let trunkTipBN: SWBatchNormLayerDesc + let trunkTipRMSNorm: SWRMSNormLayerDesc let trunkTipActivation: ActivationKind init( @@ -1496,7 +2044,9 @@ public class SWTrunkDesc { initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) { self.version = version @@ -1508,7 +2058,9 @@ public class SWTrunkDesc { self.initialMatMul = initialMatMul self.sgfMetadataEncoder = sgfMetadataEncoder self.blockDescriptors = blockDescriptors + self.trunkNormKind = trunkNormKind self.trunkTipBN = trunkTipBN + self.trunkTipRMSNorm = trunkTipRMSNorm self.trunkTipActivation = trunkTipActivation } } @@ -1523,7 +2075,9 @@ public func createSWTrunkDesc( initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int32, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) -> SWTrunkDesc { return SWTrunkDesc( @@ -1536,7 +2090,9 @@ public func createSWTrunkDesc( initialMatMul: initialMatMul, sgfMetadataEncoder: sgfMetadataEncoder, blockDescriptors: blockDescriptors, + trunkNormKind: Int(trunkNormKind), trunkTipBN: trunkTipBN, + trunkTipRMSNorm: trunkTipRMSNorm, trunkTipActivation: trunkTipActivation) } @@ -1632,21 +2188,33 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - let trunkTipBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.trunkTipBN, - nnXLen: nnXLen, - nnYLen: nnYLen, - optimizeIdentityMask: optimizeIdentityMask) + // RMSNorm trunk tip uses a fused activation; standard uses BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == TRUNK_NORM_KIND_RMSNORM { + let trunkTipRMSNorm = TrunkRMSNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipRMSNorm, + activationKind: descriptor.trunkTipActivation) - let trunkTipActivation = ActivationLayer( - graph: graph, - sourceTensor: trunkTipBN.resultTensor, - activationKind: descriptor.trunkTipActivation) + resultTensor = trunkTipRMSNorm.resultTensor + } else { + let trunkTipBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipActivation = ActivationLayer( + graph: graph, + sourceTensor: trunkTipBN.resultTensor, + activationKind: descriptor.trunkTipActivation) - resultTensor = trunkTipActivation.resultTensor + resultTensor = trunkTipActivation.resultTensor + } assert(resultTensor.shape?.count == 4) }