From b05f5594ee1db8412efe389de15d416f7d1e442e Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sat, 30 May 2026 17:11:04 +0800 Subject: [PATCH 01/18] Reduce CoreML conversion peak and ANE steady-state memory Cuts memory during the on-device KataGo -> CoreML conversion and while running the ANE/CoreML path, with byte-identical converter output: - The converter's weight tensors become non-owning views into the parsed model instead of owning extra FP32 copies; derived/transposed tensors keep an owned buffer. This drops redundant resident weight copies during conversion. CoreML model serialization is made deterministic (SetSerializationDeterministic) so the output is byte-stable. - The KataGo model parser streams the gzip through a bounded ~1 MB refill buffer instead of decompressing the whole file into memory, while preserving the existing NaN/Inf weight validation. - ModelDesc gains releaseWeights(), which frees the in-memory weight arrays (keeping scalar shape metadata). The Metal backend calls it on the ANE (CoreML) path after converting from the model file on disk, gated by a new ComputeContext::aneOnly flag so it only fires when every configured device is ANE -- the GPU/MPSGraph path keeps its weights. The call is serialized under computeHandleMutex and only scalar dims are read afterward. Measured on b18c384nbt (19x19) over the ANE path: idle steady-state RSS 0.59 GB -> 0.19 GB; peak (load+convert) 0.87 GB -> 0.48 GB. Cross-backend parity vs an Eigen reference is unchanged on both the GPU and ANE paths. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/external/katagocoreml/src/Converter.cpp | 16 +- .../katagocoreml/src/builder/MILBuilder.cpp | 18 +- .../katagocoreml/src/builder/MILBuilder.hpp | 13 +- .../katagocoreml/src/builder/Operations.cpp | 20 ++- .../katagocoreml/src/builder/Operations.hpp | 24 ++- .../katagocoreml/src/parser/KataGoParser.cpp | 168 ++++++++---------- .../katagocoreml/src/parser/KataGoParser.hpp | 16 +- .../src/serializer/CoreMLSerializer.cpp | 11 +- .../src/serializer/WeightSerializer.cpp | 10 +- cpp/neuralnet/desc.cpp | 69 +++++++ cpp/neuralnet/desc.h | 5 + cpp/neuralnet/metalbackend.cpp | 29 ++- cpp/neuralnet/metalbackend.h | 14 ++ 13 files changed, 287 insertions(+), 126 deletions(-) diff --git a/cpp/external/katagocoreml/src/Converter.cpp b/cpp/external/katagocoreml/src/Converter.cpp index cb6ca80d9..72b78e736 100644 --- a/cpp/external/katagocoreml/src/Converter.cpp +++ b/cpp/external/katagocoreml/src/Converter.cpp @@ -29,9 +29,12 @@ void KataGoConverter::convert(const std::string& input_path, throw std::invalid_argument("max_batch_size must be >= min_batch_size or <= 0 for unlimited"); } - // Parse KataGo model - KataGoParser parser(input_path); - KataGoModelDesc model = parser.parse(); + // Parse KataGo model (parser + its decompressed buffer freed at end of scope) + KataGoModelDesc model; + { + KataGoParser parser(input_path); + model = parser.parse(); + } // Determine if using FP16 precision bool use_fp16 = (options.compute_precision == "FLOAT16"); @@ -52,9 +55,8 @@ void KataGoConverter::convert(const std::string& input_path, options.use_fp16_io); auto program = builder.build(); - // Get weights from builder - auto weights = builder.getWeights(); - std::vector weights_copy(weights.begin(), weights.end()); + // Serialize directly from the builder's weight views (no copy). + std::vector& weights = builder.getWeightsMutable(); // Update options with model metadata for serialization ConversionOptions final_options = options; @@ -82,7 +84,7 @@ void KataGoConverter::convert(const std::string& input_path, // Serialize to .mlpackage CoreMLSerializer serializer(final_options.specification_version); - serializer.serialize(program.get(), weights_copy, output_path, final_options); + serializer.serialize(program.get(), weights, output_path, final_options); } ModelInfo KataGoConverter::getModelInfo(const std::string& input_path) { diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..a30d2ce43 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -212,9 +212,23 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage + // Register weight for blob storage (non-owning view into the model) m_ops.registerWeight(name, data, shape); + emitConstOp(block, name, shape); +} + +void MILBuilder::addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) { + // Register derived weight; KataGoOps takes ownership of the buffer + m_ops.registerOwnedWeight(name, std::move(data), shape); + emitConstOp(block, name, shape); +} +void MILBuilder::emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape) { // Add const operation auto* op = block->add_operations(); op->set_type("const"); @@ -958,7 +972,7 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, // Add transposed weight constant with shape [out_channels, in_channels] std::vector transposed_shape = {static_cast(out_ch), static_cast(in_ch)}; - addConstOp(block, weight_name, transposed_weights, transposed_shape); + addOwnedConstOp(block, weight_name, std::move(transposed_weights), transposed_shape); // Add bias constant std::vector bias_shape = {static_cast(bias.num_channels)}; diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..640864579 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -29,8 +29,8 @@ class MILBuilder { /// @return Unique pointer to MIL Program protobuf std::unique_ptr build(); - /// Get weight entries for blob serialization - const std::vector& getWeights() const { return m_ops.getWeights(); } + /// Get weight entries for blob serialization (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_ops.getWeightsMutable(); } /// Get board dimensions int getBoardXSize() const { return m_board_x_size; } @@ -80,6 +80,15 @@ class MILBuilder { const std::vector& data, const std::vector& shape); + void addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape); + + void emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape); + void addIntArrayConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& values); diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..148c44089 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -17,9 +17,25 @@ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& shape) { WeightEntry entry; entry.name = name; - entry.data = data; + entry.data = data.data(); + entry.count = data.size(); entry.shape = shape; - entry.blob_offset = 0; // Will be set during serialization + entry.blob_offset = 0; + m_weights.push_back(std::move(entry)); + return name; +} + +std::string KataGoOps::registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape) { + m_owned.push_back(std::move(data)); + const std::vector& stored = m_owned.back(); + WeightEntry entry; + entry.name = name; + entry.data = stored.data(); + entry.count = stored.size(); + entry.shape = shape; + entry.blob_offset = 0; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..9649cb8e6 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -5,15 +5,18 @@ #include "../types/KataGoTypes.hpp" #include +#include #include #include namespace katagocoreml { -/// Weight entry for blob file storage +/// Weight entry for blob file storage. `data`/`count` are a NON-OWNING view into +/// the live KataGoModelDesc (or into KataGoOps::m_owned for derived tensors). struct WeightEntry { std::string name; - std::vector data; + const float* data = nullptr; + size_t count = 0; std::vector shape; uint64_t blob_offset = 0; // Set during serialization }; @@ -51,16 +54,22 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight that lives in the model (stored as a non-owning view). std::string registerWeight(const std::string& name, const std::vector& data, const std::vector& shape); - /// Get all registered weights - const std::vector& getWeights() const { return m_weights; } + /// Register a derived/temporary weight; KataGoOps takes ownership so the + /// view stays valid through serialization. + std::string registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape); - /// Clear all registered weights - void clearWeights() { m_weights.clear(); } + /// Get all registered weights (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_weights; } + + /// Clear all registered weights (and their owned backing buffers) + void clearWeights() { m_weights.clear(); m_owned.clear(); } /// Generate unique operation name std::string genOpName(const std::string& prefix); @@ -71,6 +80,7 @@ class KataGoOps { bool m_optimize_identity_mask; MaskConstants m_mask_constants; std::vector m_weights; + std::deque> m_owned; int m_op_counter = 0; }; diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 2d06c27e5..19b26e90d 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -30,54 +29,41 @@ bool KataGoParser::isVersionSupported(int version) { } // ============================================================================ -// File Loading +// Stream Primitives // ============================================================================ -void KataGoParser::loadFile() { - // Check if gzip compressed - bool is_gzip = false; - if (m_model_path.size() >= 3) { - std::string ext = m_model_path.substr(m_model_path.size() - 3); - is_gzip = (ext == ".gz"); - } - - if (is_gzip) { - // Read gzipped file - gzFile gz = gzopen(m_model_path.c_str(), "rb"); - if (!gz) { - throw std::runtime_error("Cannot open gzip file: " + m_model_path); - } - - // Read in chunks - m_buffer.clear(); - std::vector chunk(1024 * 1024); // 1MB chunks - int bytes_read; - while ((bytes_read = gzread(gz, chunk.data(), static_cast(chunk.size()))) > 0) { - m_buffer.insert(m_buffer.end(), chunk.begin(), chunk.begin() + bytes_read); - } - - if (bytes_read < 0) { - int errnum; - const char* errmsg = gzerror(gz, &errnum); - gzclose(gz); - throw std::runtime_error("Error reading gzip file: " + std::string(errmsg)); - } - - gzclose(gz); - } else { - // Read regular file - std::ifstream file(m_model_path, std::ios::binary | std::ios::ate); - if (!file) { - throw std::runtime_error("Cannot open file: " + m_model_path); - } +bool KataGoParser::refill() { + if(m_gz == nullptr) return false; + int n = gzread(m_gz, m_refill.data(), (unsigned)m_refill.size()); + if(n < 0) { + int errnum; + const char* errmsg = gzerror(m_gz, &errnum); + throw std::runtime_error("Error reading gzip stream: " + std::string(errmsg)); + } + m_refillPos = 0; + m_refillLen = (size_t)n; + return n > 0; +} - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); +int KataGoParser::peekByte() { + if(m_refillPos >= m_refillLen) { + if(!refill()) return -1; + } + return (int)m_refill[m_refillPos]; +} - m_buffer.resize(static_cast(size)); - if (!file.read(reinterpret_cast(m_buffer.data()), size)) { - throw std::runtime_error("Error reading file: " + m_model_path); +void KataGoParser::readExact(uint8_t* dst, size_t n, const std::string& name) { + size_t got = 0; + while(got < n) { + if(m_refillPos >= m_refillLen) { + if(!refill()) + throw std::runtime_error(name + ": unexpected EOF in binary block"); } + size_t avail = m_refillLen - m_refillPos; + size_t take = std::min(avail, n - got); + std::memcpy(dst + got, m_refill.data() + m_refillPos, take); + m_refillPos += take; + got += take; } } @@ -86,16 +72,27 @@ void KataGoParser::loadFile() { // ============================================================================ KataGoModelDesc KataGoParser::parse() { - loadFile(); - m_pos = 0; - - // Detect if binary format (check for @BIN@ marker) - const std::string bin_marker = "@BIN@"; - auto it = std::search(m_buffer.begin(), m_buffer.end(), - bin_marker.begin(), bin_marker.end()); - m_binary_floats = (it != m_buffer.end()); - - return parseModel(); + // Allocate the refill buffer before opening the file so a bad_alloc here + // cannot leak an open gzFile handle. + m_refill.resize(1024 * 1024); + m_gz = gzopen(m_model_path.c_str(), "rb"); + if(m_gz == nullptr) + throw std::runtime_error("Cannot open file: " + m_model_path); + m_refillPos = 0; + m_refillLen = 0; + m_formatDetected = false; // decided at first readFloats + m_binary_floats = true; + KataGoModelDesc model; + try { + model = parseModel(); + } catch(...) { + gzclose(m_gz); + m_gz = nullptr; + throw; + } + gzclose(m_gz); + m_gz = nullptr; + return model; } // ============================================================================ @@ -103,24 +100,20 @@ KataGoModelDesc KataGoParser::parse() { // ============================================================================ void KataGoParser::skipWhitespace() { - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c != ' ' && c != '\t' && c != '\n' && c != '\r') { - break; - } - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c != ' ' && c != '\t' && c != '\n' && c != '\r') break; + m_refillPos++; } } void KataGoParser::readUntilWhitespace(std::string& out) { out.clear(); - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { - break; - } - out += c; - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c == ' ' || c == '\t' || c == '\n' || c == '\r') break; + out += (char)c; + m_refillPos++; } } @@ -147,37 +140,28 @@ bool KataGoParser::readBool() { std::vector KataGoParser::readFloats(size_t count, const std::string& name) { std::vector floats(count); + skipWhitespace(); + + // KataGo model files are uniformly text OR uniformly binary, so detecting the + // format once at the first weight block (binary blocks start with '@BIN@') + // is valid for all subsequent blocks. + if(!m_formatDetected) { + m_binary_floats = (peekByte() == '@'); + m_formatDetected = true; + } - if (!m_binary_floats) { + if(!m_binary_floats) { // Text format - for (size_t i = 0; i < count; i++) { + for(size_t i = 0; i < count; i++) floats[i] = readFloat(); - } } else { - // Binary format - find @BIN@ marker - while (m_pos < m_buffer.size()) { - if (m_buffer[m_pos] == '@') { - break; - } - m_pos++; - } - - // Check for @BIN@ header - if (m_pos + 5 > m_buffer.size() || - std::memcmp(&m_buffer[m_pos], "@BIN@", 5) != 0) { + // Binary: consume the "@BIN@" marker, then read count*4 raw bytes. + char marker[5]; + readExact(reinterpret_cast(marker), 5, name); + if(std::memcmp(marker, "@BIN@", 5) != 0) throw std::runtime_error(name + ": expected @BIN@ marker for binary float block"); - } - m_pos += 5; - - // Read binary floats (little-endian) - size_t num_bytes = count * 4; - if (m_pos + num_bytes > m_buffer.size()) { - throw std::runtime_error(name + ": not enough bytes for " + std::to_string(count) + " floats"); - } - // Copy as little-endian float32 - std::memcpy(floats.data(), &m_buffer[m_pos], num_bytes); - m_pos += num_bytes; + readExact(reinterpret_cast(floats.data()), count * 4, name); } // Reject NaN/Inf weights: corrupted or otherwise invalid models would diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index cbcfdefa8..2d3f1c47a 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace katagocoreml { @@ -31,9 +32,17 @@ class KataGoParser { private: std::string m_model_path; - std::vector m_buffer; - size_t m_pos = 0; + gzFile m_gz = nullptr; + std::vector m_refill; // bounded refill buffer (~1 MB) + size_t m_refillPos = 0; // read cursor within m_refill + size_t m_refillLen = 0; // valid bytes in m_refill bool m_binary_floats = true; + bool m_formatDetected = false; + + // Stream primitives + bool refill(); // returns false at EOF + int peekByte(); // -1 at EOF + void readExact(uint8_t* dst, size_t n, const std::string& name); // Low-level reading functions void readUntilWhitespace(std::string& out); @@ -65,9 +74,6 @@ class KataGoParser { // Main model parsing KataGoModelDesc parseModel(); - - // Helper to load file (handles gzip) - void loadFile(); }; } // namespace katagocoreml diff --git a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp index f271f5526..50df8003f 100644 --- a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace katagocoreml { @@ -230,8 +232,13 @@ void CoreMLSerializer::createPackage(const std::string& output_path, if (!out) { throw std::runtime_error("Failed to create temp model file"); } - if (!model->SerializeToOstream(&out)) { - throw std::runtime_error("Failed to serialize model spec"); + { + google::protobuf::io::OstreamOutputStream zos(&out); + google::protobuf::io::CodedOutputStream cos(&zos); + cos.SetSerializationDeterministic(true); + if (!model->SerializeToCodedStream(&cos)) { + throw std::runtime_error("Failed to serialize model spec"); + } } } diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..86e41aaec 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -17,18 +17,18 @@ size_t WeightSerializer::serialize(std::vector& weights, for (auto& entry : weights) { if (use_fp16) { // Convert FP32 weights to FP16 - std::vector fp16_data(entry.data.size()); - for (size_t i = 0; i < entry.data.size(); ++i) { + std::vector fp16_data(entry.count); + for (size_t i = 0; i < entry.count; ++i) { fp16_data[i] = MILBlob::Fp16::FromFloat(entry.data[i]); } MILBlob::Util::Span span(fp16_data.data(), fp16_data.size()); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(MILBlob::Fp16); + total_bytes += entry.count * sizeof(MILBlob::Fp16); } else { // Write FP32 weights - MILBlob::Util::Span span(entry.data.data(), entry.data.size()); + MILBlob::Util::Span span(entry.data, entry.count); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(float); + total_bytes += entry.count * sizeof(float); } } diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index eda55111a..e59d2e585 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -1783,6 +1783,75 @@ void ModelDesc::applyScale8ToReduceActivations() { postProcessParams.outputScaleMultiplier *= 8.0f; } +static void releaseVec(std::vector& v) { std::vector().swap(v); } + +static void releaseConv(ConvLayerDesc& c) { releaseVec(c.weights); } + +static void releaseBN(BatchNormLayerDesc& b) { + releaseVec(b.mean); releaseVec(b.variance); releaseVec(b.scale); + releaseVec(b.bias); releaseVec(b.mergedScale); releaseVec(b.mergedBias); +} + +static void releaseMatMul(MatMulLayerDesc& m) { releaseVec(m.weights); } +static void releaseMatBias(MatBiasLayerDesc& m) { releaseVec(m.weights); } + +static void releaseResidual(ResidualBlockDesc& b) { + releaseBN(b.preBN); releaseConv(b.regularConv); + releaseBN(b.midBN); releaseConv(b.finalConv); +} + +static void releaseGPool(GlobalPoolingResidualBlockDesc& b) { + releaseBN(b.preBN); releaseConv(b.regularConv); releaseConv(b.gpoolConv); + releaseBN(b.gpoolBN); releaseMatMul(b.gpoolToBiasMul); + releaseBN(b.midBN); releaseConv(b.finalConv); +} + +static void releaseBlocks(std::vector>& blocks); + +static void releaseNested(NestedBottleneckResidualBlockDesc& b) { + releaseBN(b.preBN); releaseConv(b.preConv); + releaseBlocks(b.blocks); + releaseBN(b.postBN); releaseConv(b.postConv); +} + +static void releaseBlocks(std::vector>& blocks) { + for(size_t i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) + releaseResidual(*(ResidualBlockDesc*)blocks[i].second.get()); + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) + releaseGPool(*(GlobalPoolingResidualBlockDesc*)blocks[i].second.get()); + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) + releaseNested(*(NestedBottleneckResidualBlockDesc*)blocks[i].second.get()); + else + ASSERT_UNREACHABLE; + } +} + +static void releaseSGFEncoder(SGFMetadataEncoderDesc& e) { + releaseMatMul(e.mul1); releaseMatBias(e.bias1); + releaseMatMul(e.mul2); releaseMatBias(e.bias2); + releaseMatMul(e.mul3); +} + +void ModelDesc::releaseWeights() { + releaseConv(trunk.initialConv); + releaseMatMul(trunk.initialMatMul); + if(trunk.metaEncoderVersion > 0) + releaseSGFEncoder(trunk.sgfMetadataEncoder); + releaseBlocks(trunk.blocks); + releaseBN(trunk.trunkTipBN); + releaseConv(policyHead.p1Conv); releaseConv(policyHead.g1Conv); + releaseBN(policyHead.g1BN); releaseMatMul(policyHead.gpoolToBiasMul); + releaseBN(policyHead.p1BN); releaseConv(policyHead.p2Conv); + releaseMatMul(policyHead.gpoolToPassMul); releaseMatBias(policyHead.gpoolToPassBias); + releaseMatMul(policyHead.gpoolToPassMul2); + releaseConv(valueHead.v1Conv); releaseBN(valueHead.v1BN); + releaseMatMul(valueHead.v2Mul); releaseMatBias(valueHead.v2Bias); + releaseMatMul(valueHead.v3Mul); releaseMatBias(valueHead.v3Bias); + releaseMatMul(valueHead.sv3Mul); releaseMatBias(valueHead.sv3Bias); + releaseConv(valueHead.vOwnershipConv); +} + struct NonCopyingStreamBuf : public std::streambuf { NonCopyingStreamBuf(string& str) { diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index 86676c011..9536283f2 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -389,6 +389,11 @@ struct ModelDesc { //Fills supported with true if desiredRules itself was exactly supported, false if some modifications had to be made. Rules getSupportedRules(const Rules& desiredRules, bool& supported) const; + // Frees all weight arrays (conv/matmul/bias/batchnorm), keeping scalar shape + // metadata intact. Safe once weights are no longer needed (e.g. CoreML/ANE + // inference, which reads weights from the compiled .mlmodelc). + void releaseWeights(); + }; #endif // #ifndef DESC_H diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 95adf5da4..1e7db2fba 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -425,14 +425,28 @@ ComputeContext* NeuralNet::createComputeContext( enabled_t useNHWCMode, const LoadedModel* loadedModel) { - (void)gpuIdxs; + // Only ANE-only configurations may free the engine's in-memory weights: the + // GPU/MPSGraph path reads them via modelDescToSwift, so freeing is unsafe + // unless no GPU handle can ever be built from this model. + // INVARIANT: gpuIdxs must be the complete (deduplicated) set of device indices + // that will ever be passed as gpuIdxForThisThread to createComputeHandle for + // this context. aneOnly==true frees the in-memory weights, so if any thread + // later used a GPU (MPSGraph) index not represented here, it would read freed + // weights. KataGo derives both from the same gpuIdxByServerThread list, so the + // invariant holds today; preserve it if that wiring ever changes. + bool aneOnly = !gpuIdxs.empty(); + for(int idx : gpuIdxs) { + if(idx != METAL_MUX_ANE) { aneOnly = false; break; } + } (void)logger; (void)openCLTunerFile; (void)homeDataDirOverride; (void)openCLReTunePerBoardSize; (void)loadedModel; - return new ComputeContext(nnXLen, nnYLen, useFP16Mode, useNHWCMode); + ComputeContext* context = new ComputeContext(nnXLen, nnYLen, useFP16Mode, useNHWCMode); + context->aneOnly = aneOnly; + return context; } void NeuralNet::freeComputeContext(ComputeContext* computeContext) { @@ -459,6 +473,17 @@ static swift::Optional convertAndCreateCoreMLO bool useFP16 = (context->useFP16Mode != enabled_t::False); bool optimizeMask = requireExactNNLen; + // On a confirmed ANE-only run, free the engine's in-memory ModelDesc weight + // arrays. This function converts from loadedModel->modelPath (disk), + // so the in-memory weights are not read here; the GPU/MPSGraph path (which + // DOES read them via modelDescToSwift) is never built when aneOnly is true. + // The whole ComputeHandle ctor runs under computeHandleMutex, so this is not + // racy; releaseWeights() clears only weight vectors, leaving the scalar dims + // read by the ComputeHandle ctor / InputBuffers valid. + if(context->aneOnly) { + const_cast(loadedModel)->modelDesc.releaseWeights(); + } + // Convert model to CoreML format in temp directory string coremlModelPath = CoreMLConversion::convertModelToTemp( loadedModel->modelPath, diff --git a/cpp/neuralnet/metalbackend.h b/cpp/neuralnet/metalbackend.h index a00f21864..db161d28c 100644 --- a/cpp/neuralnet/metalbackend.h +++ b/cpp/neuralnet/metalbackend.h @@ -113,6 +113,14 @@ struct ComputeContext { */ MetalComputeContext metalContext; + /** + * @brief True only when every configured device is METAL_MUX_ANE, so no + * MPSGraph (GPU) handle will ever read modelDesc weights. Gates the call to + * ModelDesc::releaseWeights() so a mixed GPU+ANE config can never free live + * weights. + */ + bool aneOnly = false; + /** * @brief Constructs a ComputeContext object. * @param nnX The width of the input tensor. @@ -180,6 +188,12 @@ struct ComputeHandle { */ bool maskIdentityChecked = false; + // IMPORTANT (weight-release safety): mpsGraphOnlyHandle MUST be declared + // before coremlOnlyHandle. C++ initializes members in DECLARATION order, so + // createMPSGraphHandleIfNeeded (which reads modelDesc weights via + // modelDescToSwift) runs before createCoreMLOnlyHandleIfNeeded (which may call + // modelDesc.releaseWeights() on an ANE-only run). Reordering these would let a + // GPU handle read freed weights. Do not reorder. /** * @brief The MPSGraph-only handle instance from Swift (GPU-only mode). */ From 971fa9d8c0bd9fafd7987f25edfcb5cc96c38c1d Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sat, 30 May 2026 18:37:44 +0800 Subject: [PATCH 02/18] Enforce non-owning weight-view contract at compile time WeightEntry stores a non-owning view (const float*, count) into the live KataGoModelDesc, so the backing std::vector must outlive serialization. addConstOp/registerWeight took the data by const& and silently stored a pointer to it; a caller passing a temporary would bind to that const& and leave the view dangling, read much later during serialization. Delete the rvalue overloads of both so any such call fails to compile, forcing temporaries through addOwnedConstOp/registerOwnedWeight (which take ownership). Named lvalues (the model-member call sites) still bind to the const& overload, so no existing caller changes. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/external/katagocoreml/src/builder/MILBuilder.hpp | 9 +++++++++ cpp/external/katagocoreml/src/builder/Operations.hpp | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 640864579..a25c3f537 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -80,6 +80,15 @@ class MILBuilder { const std::vector& data, const std::vector& shape); + // addConstOp registers a NON-OWNING view into `data` (see WeightEntry), so the + // backing storage must outlive serialization. Binding a temporary here would + // dangle. Deleted so such calls fail to compile; use addOwnedConstOp for + // derived/temporary tensors that KataGoOps should own instead. + void addConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + void addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, std::vector&& data, diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 9649cb8e6..f5431f79f 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -59,6 +59,13 @@ class KataGoOps { const std::vector& data, const std::vector& shape); + /// The stored WeightEntry is a non-owning view into `data`, so a temporary + /// would leave it dangling. Deleted to reject such calls at compile time; + /// use registerOwnedWeight for tensors KataGoOps should own. + std::string registerWeight(const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + /// Register a derived/temporary weight; KataGoOps takes ownership so the /// view stays valid through serialization. std::string registerOwnedWeight(const std::string& name, From eeefc976222fcaf67fc4092300b6e39db38c634d Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sun, 31 May 2026 10:31:12 +0800 Subject: [PATCH 03/18] RAII the gzFile handle in KataGoParser Own the gzFile with a custom-deleter unique_ptr so it closes on every exit path (normal return, exception, bad_alloc); removes the manual try/catch+gzclose in parse() and the ordering caveat on buffer allocation. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/parser/KataGoParser.cpp | 26 ++++++------------- .../katagocoreml/src/parser/KataGoParser.hpp | 10 ++++++- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 19b26e90d..37497f6cc 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -33,11 +33,11 @@ bool KataGoParser::isVersionSupported(int version) { // ============================================================================ bool KataGoParser::refill() { - if(m_gz == nullptr) return false; - int n = gzread(m_gz, m_refill.data(), (unsigned)m_refill.size()); + if(!m_gz) return false; + int n = gzread(m_gz.get(), m_refill.data(), (unsigned)m_refill.size()); if(n < 0) { int errnum; - const char* errmsg = gzerror(m_gz, &errnum); + const char* errmsg = gzerror(m_gz.get(), &errnum); throw std::runtime_error("Error reading gzip stream: " + std::string(errmsg)); } m_refillPos = 0; @@ -72,27 +72,17 @@ void KataGoParser::readExact(uint8_t* dst, size_t n, const std::string& name) { // ============================================================================ KataGoModelDesc KataGoParser::parse() { - // Allocate the refill buffer before opening the file so a bad_alloc here - // cannot leak an open gzFile handle. + // Allocate the refill buffer first; if this throws, no handle has been opened. m_refill.resize(1024 * 1024); - m_gz = gzopen(m_model_path.c_str(), "rb"); - if(m_gz == nullptr) + m_gz.reset(gzopen(m_model_path.c_str(), "rb")); + if(!m_gz) throw std::runtime_error("Cannot open file: " + m_model_path); m_refillPos = 0; m_refillLen = 0; m_formatDetected = false; // decided at first readFloats m_binary_floats = true; - KataGoModelDesc model; - try { - model = parseModel(); - } catch(...) { - gzclose(m_gz); - m_gz = nullptr; - throw; - } - gzclose(m_gz); - m_gz = nullptr; - return model; + // ~GzHandle closes the file on normal return OR exception — no try/catch needed. + return parseModel(); } // ============================================================================ diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index 2d3f1c47a..8ee1a90ab 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -5,7 +5,9 @@ #include "../types/KataGoTypes.hpp" #include +#include #include +#include #include #include @@ -32,7 +34,13 @@ class KataGoParser { private: std::string m_model_path; - gzFile m_gz = nullptr; + // Custom-deleter unique_ptr owns the gzFile so it closes on every exit path + // (normal return, exception, or bad_alloc) without manual try/catch. + struct GzCloser { + void operator()(gzFile f) const noexcept { if(f) gzclose(f); } + }; + using GzHandle = std::unique_ptr, GzCloser>; + GzHandle m_gz; std::vector m_refill; // bounded refill buffer (~1 MB) size_t m_refillPos = 0; // read cursor within m_refill size_t m_refillLen = 0; // valid bytes in m_refill From 6bfa617b9fafe0be9b40b04196be7f94ed22a8f6 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sun, 31 May 2026 10:31:14 +0800 Subject: [PATCH 04/18] Replace WeightEntry raw ptr+count with a local FloatView Introduce a KataGo-local non-owning FloatView for WeightEntry::data instead of a raw const float*/size_t pair; convert to MILBlob::Util::Span only inside WeightSerializer, keeping the MILBlob dependency out of Operations.hpp. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/Operations.cpp | 6 ++---- .../katagocoreml/src/builder/Operations.hpp | 19 +++++++++++++++---- .../src/serializer/WeightSerializer.cpp | 13 +++++++------ 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index 148c44089..4cbd1038a 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -17,8 +17,7 @@ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& shape) { WeightEntry entry; entry.name = name; - entry.data = data.data(); - entry.count = data.size(); + entry.data = FloatView{data.data(), data.size()}; entry.shape = shape; entry.blob_offset = 0; m_weights.push_back(std::move(entry)); @@ -32,8 +31,7 @@ std::string KataGoOps::registerOwnedWeight(const std::string& name, const std::vector& stored = m_owned.back(); WeightEntry entry; entry.name = name; - entry.data = stored.data(); - entry.count = stored.size(); + entry.data = FloatView{stored.data(), stored.size()}; entry.shape = shape; entry.blob_offset = 0; m_weights.push_back(std::move(entry)); diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index f5431f79f..5bc8378e2 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -11,12 +11,23 @@ namespace katagocoreml { -/// Weight entry for blob file storage. `data`/`count` are a NON-OWNING view into -/// the live KataGoModelDesc (or into KataGoOps::m_owned for derived tensors). +/// Minimal non-owning view over a contiguous float buffer. KataGo-local on +/// purpose: keeps the MILBlob dependency out of this header (conversion to +/// MILBlob::Util::Span happens only at the serializer boundary). +struct FloatView { + const float* ptr = nullptr; + size_t len = 0; + const float* data() const { return ptr; } + size_t size() const { return len; } + bool empty() const { return len == 0; } + float operator[](size_t i) const { return ptr[i]; } +}; + +/// Weight entry for blob file storage. `data` is a NON-OWNING view into the live +/// KataGoModelDesc (or into KataGoOps::m_owned for derived tensors). struct WeightEntry { std::string name; - const float* data = nullptr; - size_t count = 0; + FloatView data; // non-owning view (replaces raw ptr + count) std::vector shape; uint64_t blob_offset = 0; // Set during serialization }; diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 86e41aaec..e27a342f7 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,20 +15,21 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { + const size_t count = entry.data.size(); if (use_fp16) { // Convert FP32 weights to FP16 - std::vector fp16_data(entry.count); - for (size_t i = 0; i < entry.count; ++i) { + std::vector fp16_data(count); + for (size_t i = 0; i < count; ++i) { fp16_data[i] = MILBlob::Fp16::FromFloat(entry.data[i]); } MILBlob::Util::Span span(fp16_data.data(), fp16_data.size()); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.count * sizeof(MILBlob::Fp16); + total_bytes += count * sizeof(MILBlob::Fp16); } else { - // Write FP32 weights - MILBlob::Util::Span span(entry.data, entry.count); + // Write FP32 weights — convert the KataGo-local view to a MILBlob span here. + MILBlob::Util::Span span(entry.data.data(), count); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.count * sizeof(float); + total_bytes += count * sizeof(float); } } From 415993015e80674a03d3c34b06af6a2773047483 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sun, 31 May 2026 21:57:00 +0800 Subject: [PATCH 05/18] Clarify weight-release safety comment: aneOnly is the guarantee The ComputeHandle member-order comment claimed that declaring mpsGraphOnlyHandle before coremlOnlyHandle is what prevents a GPU handle from reading freed weights. That overstates the ordering's role: within a single ComputeHandle exactly one handle is built (mutually exclusive on gpuIdx, enforced by the ctor's exactly-one check), and releaseWeights() only fires on an aneOnly context where no MPSGraph handle is ever built. Reframe the declaration order as belt-and-suspenders and point at ComputeContext::aneOnly as the actual invariant. Comment-only change. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metalbackend.h | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cpp/neuralnet/metalbackend.h b/cpp/neuralnet/metalbackend.h index db161d28c..3922d6828 100644 --- a/cpp/neuralnet/metalbackend.h +++ b/cpp/neuralnet/metalbackend.h @@ -188,12 +188,18 @@ struct ComputeHandle { */ bool maskIdentityChecked = false; - // IMPORTANT (weight-release safety): mpsGraphOnlyHandle MUST be declared - // before coremlOnlyHandle. C++ initializes members in DECLARATION order, so - // createMPSGraphHandleIfNeeded (which reads modelDesc weights via - // modelDescToSwift) runs before createCoreMLOnlyHandleIfNeeded (which may call - // modelDesc.releaseWeights() on an ANE-only run). Reordering these would let a - // GPU handle read freed weights. Do not reorder. + // Weight-release safety is guaranteed by ComputeContext::aneOnly, NOT by the + // declaration order below: within a single ComputeHandle exactly one handle is + // built (the two paths are mutually exclusive on gpuIdx, enforced by the + // ctor's exactly-one check), and releaseWeights() only ever fires on an + // aneOnly context, where no MPSGraph handle is built for any thread. + // That said, keep mpsGraphOnlyHandle declared before coremlOnlyHandle. C++ + // initializes members in DECLARATION order, so createMPSGraphHandleIfNeeded + // (which reads modelDesc weights via modelDescToSwift) is sequenced before + // createCoreMLOnlyHandleIfNeeded (which may call modelDesc.releaseWeights()). + // This ordering is belt-and-suspenders that preserves the natural read-then- + // release sequence should the aneOnly invariant ever be weakened; don't rely + // on it as the primary guarantee, but don't reorder it either. /** * @brief The MPSGraph-only handle instance from Swift (GPU-only mode). */ From 44342a388c3629ded3ed6aa6a5ca184614e6f2ab Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sun, 31 May 2026 22:58:36 +0800 Subject: [PATCH 06/18] Refactor weight release into per-struct releaseWeights() methods Replace the file-local releaseXXX free functions in desc.cpp (which reached into each desc struct's internals from outside) with releaseWeights() member methods on each weight-bearing struct, matching the existing OO convention used by applyScale8ToReduceActivations() and iterConvLayers(). Each container delegates to its members; type-erased block dispatch is inlined with the same cast pattern those methods use. Behavior-preserving: same set of freed vectors, same block recursion, same metaEncoderVersion guard. ModelDesc::releaseWeights() keeps its signature, so the metalbackend.cpp call site is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/desc.cpp | 152 ++++++++++++++++++++++++++++------------- cpp/neuralnet/desc.h | 22 ++++++ 2 files changed, 126 insertions(+), 48 deletions(-) diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index e59d2e585..0c4943bb5 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -1783,73 +1783,129 @@ void ModelDesc::applyScale8ToReduceActivations() { postProcessParams.outputScaleMultiplier *= 8.0f; } -static void releaseVec(std::vector& v) { std::vector().swap(v); } +void ConvLayerDesc::releaseWeights() { + std::vector().swap(weights); +} -static void releaseConv(ConvLayerDesc& c) { releaseVec(c.weights); } +void BatchNormLayerDesc::releaseWeights() { + std::vector().swap(mean); + std::vector().swap(variance); + std::vector().swap(scale); + std::vector().swap(bias); + std::vector().swap(mergedScale); + std::vector().swap(mergedBias); +} -static void releaseBN(BatchNormLayerDesc& b) { - releaseVec(b.mean); releaseVec(b.variance); releaseVec(b.scale); - releaseVec(b.bias); releaseVec(b.mergedScale); releaseVec(b.mergedBias); +void MatMulLayerDesc::releaseWeights() { + std::vector().swap(weights); } -static void releaseMatMul(MatMulLayerDesc& m) { releaseVec(m.weights); } -static void releaseMatBias(MatBiasLayerDesc& m) { releaseVec(m.weights); } +void MatBiasLayerDesc::releaseWeights() { + std::vector().swap(weights); +} -static void releaseResidual(ResidualBlockDesc& b) { - releaseBN(b.preBN); releaseConv(b.regularConv); - releaseBN(b.midBN); releaseConv(b.finalConv); +void ResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); } -static void releaseGPool(GlobalPoolingResidualBlockDesc& b) { - releaseBN(b.preBN); releaseConv(b.regularConv); releaseConv(b.gpoolConv); - releaseBN(b.gpoolBN); releaseMatMul(b.gpoolToBiasMul); - releaseBN(b.midBN); releaseConv(b.finalConv); +void GlobalPoolingResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + gpoolConv.releaseWeights(); + gpoolBN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); } -static void releaseBlocks(std::vector>& blocks); +void NestedBottleneckResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + preConv.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + postBN.releaseWeights(); + postConv.releaseWeights(); +} -static void releaseNested(NestedBottleneckResidualBlockDesc& b) { - releaseBN(b.preBN); releaseConv(b.preConv); - releaseBlocks(b.blocks); - releaseBN(b.postBN); releaseConv(b.postConv); +void SGFMetadataEncoderDesc::releaseWeights() { + mul1.releaseWeights(); + bias1.releaseWeights(); + mul2.releaseWeights(); + bias2.releaseWeights(); + mul3.releaseWeights(); } -static void releaseBlocks(std::vector>& blocks) { - for(size_t i = 0; i < blocks.size(); i++) { - if(blocks[i].first == ORDINARY_BLOCK_KIND) - releaseResidual(*(ResidualBlockDesc*)blocks[i].second.get()); - else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) - releaseGPool(*(GlobalPoolingResidualBlockDesc*)blocks[i].second.get()); - else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) - releaseNested(*(NestedBottleneckResidualBlockDesc*)blocks[i].second.get()); - else +void TrunkDesc::releaseWeights() { + initialConv.releaseWeights(); + initialMatMul.releaseWeights(); + if(metaEncoderVersion > 0) + sgfMetadataEncoder.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { ASSERT_UNREACHABLE; + } } + trunkTipBN.releaseWeights(); +} + +void PolicyHeadDesc::releaseWeights() { + p1Conv.releaseWeights(); + g1Conv.releaseWeights(); + g1BN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + p1BN.releaseWeights(); + p2Conv.releaseWeights(); + gpoolToPassMul.releaseWeights(); + gpoolToPassBias.releaseWeights(); + gpoolToPassMul2.releaseWeights(); } -static void releaseSGFEncoder(SGFMetadataEncoderDesc& e) { - releaseMatMul(e.mul1); releaseMatBias(e.bias1); - releaseMatMul(e.mul2); releaseMatBias(e.bias2); - releaseMatMul(e.mul3); +void ValueHeadDesc::releaseWeights() { + v1Conv.releaseWeights(); + v1BN.releaseWeights(); + v2Mul.releaseWeights(); + v2Bias.releaseWeights(); + v3Mul.releaseWeights(); + v3Bias.releaseWeights(); + sv3Mul.releaseWeights(); + sv3Bias.releaseWeights(); + vOwnershipConv.releaseWeights(); } void ModelDesc::releaseWeights() { - releaseConv(trunk.initialConv); - releaseMatMul(trunk.initialMatMul); - if(trunk.metaEncoderVersion > 0) - releaseSGFEncoder(trunk.sgfMetadataEncoder); - releaseBlocks(trunk.blocks); - releaseBN(trunk.trunkTipBN); - releaseConv(policyHead.p1Conv); releaseConv(policyHead.g1Conv); - releaseBN(policyHead.g1BN); releaseMatMul(policyHead.gpoolToBiasMul); - releaseBN(policyHead.p1BN); releaseConv(policyHead.p2Conv); - releaseMatMul(policyHead.gpoolToPassMul); releaseMatBias(policyHead.gpoolToPassBias); - releaseMatMul(policyHead.gpoolToPassMul2); - releaseConv(valueHead.v1Conv); releaseBN(valueHead.v1BN); - releaseMatMul(valueHead.v2Mul); releaseMatBias(valueHead.v2Bias); - releaseMatMul(valueHead.v3Mul); releaseMatBias(valueHead.v3Bias); - releaseMatMul(valueHead.sv3Mul); releaseMatBias(valueHead.sv3Bias); - releaseConv(valueHead.vOwnershipConv); + trunk.releaseWeights(); + policyHead.releaseWeights(); + valueHead.releaseWeights(); } struct NonCopyingStreamBuf : public std::streambuf diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index 9536283f2..6b612207d 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -34,6 +34,8 @@ struct ConvLayerDesc { double getSpatialConvDepth() const; void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct BatchNormLayerDesc { @@ -64,6 +66,8 @@ struct BatchNormLayerDesc { void extractChannelFactorsAbsLtOne(std::vector& channelFactors); void extractChannelFactorsAbsLtOneWithInverses(std::vector& channelFactors, std::vector& invChannelFactors); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ActivationLayerDesc { @@ -99,6 +103,8 @@ struct MatMulLayerDesc { MatMulLayerDesc& operator=(MatMulLayerDesc&& other); void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct MatBiasLayerDesc { @@ -115,6 +121,8 @@ struct MatBiasLayerDesc { MatBiasLayerDesc& operator=(MatBiasLayerDesc&& other); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ResidualBlockDesc { @@ -140,6 +148,8 @@ struct ResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct GlobalPoolingResidualBlockDesc { @@ -170,6 +180,8 @@ struct GlobalPoolingResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct NestedBottleneckResidualBlockDesc { @@ -200,6 +212,8 @@ struct NestedBottleneckResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct SGFMetadataEncoderDesc { @@ -223,6 +237,8 @@ struct SGFMetadataEncoderDesc { SGFMetadataEncoderDesc& operator=(const SGFMetadataEncoderDesc&) = delete; SGFMetadataEncoderDesc& operator=(SGFMetadataEncoderDesc&& other); + + void releaseWeights(); }; @@ -263,6 +279,8 @@ struct TrunkDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct PolicyHeadDesc { @@ -296,6 +314,8 @@ struct PolicyHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ValueHeadDesc { @@ -327,6 +347,8 @@ struct ValueHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ModelPostProcessParams { From 98b17ebbf2459e00dcca8120e83e766aff335ab0 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Sun, 31 May 2026 23:22:13 +0800 Subject: [PATCH 07/18] Co-locate releaseWeights() defs with each struct's other methods Move the 11 leaf/container releaseWeights() definitions in desc.cpp out of the bottom cluster (inherited from the old free-function layout) and place each immediately after its struct's last existing method, matching the file's per-struct grouping convention used by every other method. ModelDesc::releaseWeights() stays put, already adjacent to its siblings. Pure relocation: function bodies and desc.h are unchanged; only two stray double-blank lines were normalized to single. Verified clean Metal build, testgpuerror vs Eigen reference (g170-b6c96) at <0.0004% winrate error, and runtests all pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/desc.cpp | 236 ++++++++++++++++++++--------------------- 1 file changed, 117 insertions(+), 119 deletions(-) diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 0c4943bb5..2fc47f340 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -193,6 +193,10 @@ void ConvLayerDesc::scaleOutputChannels(const std::vector& scaling) { } } +void ConvLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- BatchNormLayerDesc::BatchNormLayerDesc() : numChannels(0), epsilon(0.001f), hasScale(false), hasBias(false) {} @@ -377,6 +381,15 @@ ActivationLayerDesc::ActivationLayerDesc(istream& in, int modelVersion) { } } +void BatchNormLayerDesc::releaseWeights() { + std::vector().swap(mean); + std::vector().swap(variance); + std::vector().swap(scale); + std::vector().swap(bias); + std::vector().swap(mergedScale); + std::vector().swap(mergedBias); +} + ActivationLayerDesc::ActivationLayerDesc(ActivationLayerDesc&& other) { *this = std::move(other); } @@ -487,6 +500,10 @@ MatBiasLayerDesc::MatBiasLayerDesc(istream& in, bool binaryFloats) { throw StringError(name + ": matbiaslayer failed to parse expected number of matbias weights"); } +void MatMulLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + MatBiasLayerDesc::MatBiasLayerDesc(MatBiasLayerDesc&& other) { *this = std::move(other); } @@ -504,6 +521,10 @@ void MatBiasLayerDesc::applyScale8ToReduceActivations() { } } +void MatBiasLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- ResidualBlockDesc::ResidualBlockDesc() {} @@ -575,6 +596,13 @@ void ResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void ResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- GlobalPoolingResidualBlockDesc::GlobalPoolingResidualBlockDesc() {} @@ -685,6 +713,16 @@ void GlobalPoolingResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void GlobalPoolingResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + gpoolConv.releaseWeights(); + gpoolBN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- NestedBottleneckResidualBlockDesc::NestedBottleneckResidualBlockDesc() {} @@ -847,6 +885,30 @@ void NestedBottleneckResidualBlockDesc::applyScale8ToReduceActivations() { postActivation.applyScale8ToReduceActivations(); } +void NestedBottleneckResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + preConv.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + postBN.releaseWeights(); + postConv.releaseWeights(); +} + //----------------------------------------------------------------------------- static void parseResidualBlockStack( @@ -1009,6 +1071,14 @@ SGFMetadataEncoderDesc& SGFMetadataEncoderDesc::operator=(SGFMetadataEncoderDesc return *this; } +void SGFMetadataEncoderDesc::releaseWeights() { + mul1.releaseWeights(); + bias1.releaseWeights(); + mul2.releaseWeights(); + bias2.releaseWeights(); + mul3.releaseWeights(); +} + //----------------------------------------------------------------------------- TrunkDesc::TrunkDesc() @@ -1259,6 +1329,30 @@ void TrunkDesc::applyScale8ToReduceActivations() { } } +void TrunkDesc::releaseWeights() { + initialConv.releaseWeights(); + initialMatMul.releaseWeights(); + if(metaEncoderVersion > 0) + sgfMetadataEncoder.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + trunkTipBN.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -1406,6 +1500,18 @@ void PolicyHeadDesc::applyScale8ToReduceActivations() { passActivation.applyScale8ToReduceActivations(); } +void PolicyHeadDesc::releaseWeights() { + p1Conv.releaseWeights(); + g1Conv.releaseWeights(); + g1BN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + p1BN.releaseWeights(); + p2Conv.releaseWeights(); + gpoolToPassMul.releaseWeights(); + gpoolToPassBias.releaseWeights(); + gpoolToPassMul2.releaseWeights(); +} + //----------------------------------------------------------------------------- ValueHeadDesc::ValueHeadDesc() : modelVersion(-1) {} @@ -1541,6 +1647,17 @@ void ValueHeadDesc::applyScale8ToReduceActivations() { sv3Bias.applyScale8ToReduceActivations(); } +void ValueHeadDesc::releaseWeights() { + v1Conv.releaseWeights(); + v1BN.releaseWeights(); + v2Mul.releaseWeights(); + v2Bias.releaseWeights(); + v3Mul.releaseWeights(); + v3Bias.releaseWeights(); + sv3Mul.releaseWeights(); + sv3Bias.releaseWeights(); + vOwnershipConv.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -1783,125 +1900,6 @@ void ModelDesc::applyScale8ToReduceActivations() { postProcessParams.outputScaleMultiplier *= 8.0f; } -void ConvLayerDesc::releaseWeights() { - std::vector().swap(weights); -} - -void BatchNormLayerDesc::releaseWeights() { - std::vector().swap(mean); - std::vector().swap(variance); - std::vector().swap(scale); - std::vector().swap(bias); - std::vector().swap(mergedScale); - std::vector().swap(mergedBias); -} - -void MatMulLayerDesc::releaseWeights() { - std::vector().swap(weights); -} - -void MatBiasLayerDesc::releaseWeights() { - std::vector().swap(weights); -} - -void ResidualBlockDesc::releaseWeights() { - preBN.releaseWeights(); - regularConv.releaseWeights(); - midBN.releaseWeights(); - finalConv.releaseWeights(); -} - -void GlobalPoolingResidualBlockDesc::releaseWeights() { - preBN.releaseWeights(); - regularConv.releaseWeights(); - gpoolConv.releaseWeights(); - gpoolBN.releaseWeights(); - gpoolToBiasMul.releaseWeights(); - midBN.releaseWeights(); - finalConv.releaseWeights(); -} - -void NestedBottleneckResidualBlockDesc::releaseWeights() { - preBN.releaseWeights(); - preConv.releaseWeights(); - for(int i = 0; i < blocks.size(); i++) { - if(blocks[i].first == ORDINARY_BLOCK_KIND) { - ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { - GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { - NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else { - ASSERT_UNREACHABLE; - } - } - postBN.releaseWeights(); - postConv.releaseWeights(); -} - -void SGFMetadataEncoderDesc::releaseWeights() { - mul1.releaseWeights(); - bias1.releaseWeights(); - mul2.releaseWeights(); - bias2.releaseWeights(); - mul3.releaseWeights(); -} - -void TrunkDesc::releaseWeights() { - initialConv.releaseWeights(); - initialMatMul.releaseWeights(); - if(metaEncoderVersion > 0) - sgfMetadataEncoder.releaseWeights(); - for(int i = 0; i < blocks.size(); i++) { - if(blocks[i].first == ORDINARY_BLOCK_KIND) { - ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { - GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { - NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); - desc->releaseWeights(); - } - else { - ASSERT_UNREACHABLE; - } - } - trunkTipBN.releaseWeights(); -} - -void PolicyHeadDesc::releaseWeights() { - p1Conv.releaseWeights(); - g1Conv.releaseWeights(); - g1BN.releaseWeights(); - gpoolToBiasMul.releaseWeights(); - p1BN.releaseWeights(); - p2Conv.releaseWeights(); - gpoolToPassMul.releaseWeights(); - gpoolToPassBias.releaseWeights(); - gpoolToPassMul2.releaseWeights(); -} - -void ValueHeadDesc::releaseWeights() { - v1Conv.releaseWeights(); - v1BN.releaseWeights(); - v2Mul.releaseWeights(); - v2Bias.releaseWeights(); - v3Mul.releaseWeights(); - v3Bias.releaseWeights(); - sv3Mul.releaseWeights(); - sv3Bias.releaseWeights(); - vOwnershipConv.releaseWeights(); -} - void ModelDesc::releaseWeights() { trunk.releaseWeights(); policyHead.releaseWeights(); From b4459dada3bd288fd108eac15b678e1ab372abf8 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:40:19 +0800 Subject: [PATCH 08/18] Add Metal GPU + CoreML/ANE transformer support for b10c384h6nbttflrs (v15) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement the LLaMA-style transformer-hybrid forward pass (RMSNorm, multi-head attention with learnable 2D RoPE, SwiGLU FFN) plus ACTIVATION_SILU across the Metal GPU (MPSGraph) and CoreML/ANE (MIL) backends, so the v15 b10c384h6nbttflrs model runs end-to-end. Metal GPU (MPSGraph) — verified via testgpuerror vs Eigen reference at sizes 9/13/19 (winrate error ~0.0001%, well under threshold): - metallayers.swift: TransformerRMSNormLayer, TrunkRMSNormLayer, TransformerAttentionBlock, TransformerFFNBlock, silu() activation, SWTransformer*/SWRMSNorm descriptors; Trunk branches on trunkNormKind - metalbackend.cpp: SILU bridge + transformer/RMSNorm desc bridges, wired into residualBlocksToSwift and trunkDescToSwift CoreML/ANE (katagocoreml MIL) — implemented end-to-end; fp32 model logically correct and consistent across CPU/ANE/GPU. fp16 ANE path is numerically precision-limited (~5%) due to fp16 matmul accumulation in the deep attention stack: - types/parser: ActivationType::Silu, trunk_norm_kind, transformer block kinds 4/5, RMSNorm/attention/FFN descriptors - MILBuilder: addSiluOps, RMSNorm ops, transformer attention/FFN blocks. Fixes 4 CoreML bugs: reshape-after-transpose, fp16 mask overflow, fp16 RMSNorm reduce_sum overflow (reduce_mean) Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 595 +++++++++++++++++- .../katagocoreml/src/builder/MILBuilder.hpp | 38 ++ .../katagocoreml/src/parser/KataGoParser.cpp | 136 +++- .../katagocoreml/src/parser/KataGoParser.hpp | 4 + .../katagocoreml/src/types/KataGoTypes.hpp | 68 +- cpp/neuralnet/metalbackend.cpp | 79 ++- cpp/neuralnet/metallayers.swift | 578 ++++++++++++++++- 7 files changed, 1467 insertions(+), 31 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..2a0bbf44a 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -4,6 +4,7 @@ #include "MILBuilder.hpp" #include "MILBlob/Fp16.hpp" #include +#include // Include generated protobuf headers #include "MIL.pb.h" @@ -732,6 +733,63 @@ void MILBuilder::addBatchNormActivationOps(CoreML::Specification::MILSpec::Block setTensorOutput4D(op, output, bn.num_channels, m_board_y_size, m_board_x_size); } else if (act.activation_type == ActivationType::Mish) { addMishOps(block, bn_output, output, 4, bn.num_channels); + } else if (act.activation_type == ActivationType::Silu) { + addSiluOps(block, bn_output, output, 4, bn.num_channels); + } +} + +void MILBuilder::addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels) { + // SiLU / Swish: x * sigmoid(x) + auto setOutputType = [this, rank, channels](CoreML::Specification::MILSpec::Operation* op, const std::string& name) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* out_type = out->mutable_type()->mutable_tensortype(); + out_type->set_datatype(m_weight_dtype); + out_type->set_rank(rank); + setBatchDimension(out_type); + out_type->add_dimensions()->mutable_constant()->set_size(channels); + if (rank == 4) { + out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); + out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + } + }; + + std::string sig = output + "_sigmoid"; + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + setOutputType(op, sig); + } + { + auto* op = block->add_operations(); + op->set_type("mul"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + inputs["y"].add_arguments()->set_name(sig); + setOutputType(op, output); + } +} + +void MILBuilder::setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* t = out->mutable_type()->mutable_tensortype(); + t->set_datatype(m_weight_dtype); + t->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + auto* dim = t->add_dimensions(); + if (d < 0) + dim->mutable_unknown()->set_variadic(false); + else + dim->mutable_constant()->set_size(d); } } @@ -1637,6 +1695,522 @@ void MILBuilder::addGlobalPoolingValueOps(CoreML::Specification::MILSpec::Block* // Network Component Builders // ============================================================================ +// --------------------------------------------------------------------------- +// Transformer blocks (MIL). Layout is NCHW [B, C, H, W]; spatial positions +// (H*W, ordered y*W+x) are treated as the attention sequence. RoPE is applied +// via a fixed pair-rotation matmul plus host-precomputed cos/sin tables, which +// keeps every tensor rank <= 4 (ANE-friendly). +// --------------------------------------------------------------------------- + +std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", input, input, sq, {-1, C, H, W}); + // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so + // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can + // overflow FP16 (and the FP16 accumulation on ANE) for large activations. + std::string meanSq = genVarName(prefix + "_meansq"); + { + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + } + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string inv = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, inv, {-1, 1, H, W}); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string weightName = prefix + "_weight"; + addConstOp(block, weightName, desc.weight, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, weightName, scaled, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + emit2("mul", scaled, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + auto reduceSum = [&](const std::string& x, const std::string& out, const std::vector& axes, + const std::vector& dims) { + std::string axesName = out + "_axes"; + std::string keepName = out + "_keep"; + addIntArrayConstOp(block, axesName, axes); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_sum"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, out, dims); + }; + + std::string masked = genVarName(prefix + "_premask"); + emit2("mul", input, mask, masked, {-1, C, H, W}); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", masked, masked, sq, {-1, C, H, W}); + + std::string meanSq; + std::vector denomDims; + if (desc.spatial) { + // Mean of squares over valid positions and channels. A reduce_sum over C*H*W elements + // overflows FP16 (e.g. trunk tip with large activations on ANE -> inf -> rsqrt 0 -> + // collapse). Instead take reduce_mean over all of C,H,W (masked positions are zero) and + // rescale by totalPositions/validCount to restrict the mean to valid positions. + std::string meanAll = genVarName(prefix + "_meanall"); + { + std::string axesName = meanAll + "_axes", keepName = meanAll + "_keep"; + addIntArrayConstOp(block, axesName, {1, 2, 3}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanAll, {-1, 1, 1, 1}); + } + std::string count = genVarName(prefix + "_count"); + reduceSum(mask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + std::string totalPosName = prefix + "_totalpos"; + addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); + std::string scaleF = genVarName(prefix + "_scalef"); + emit2("real_div", totalPosName, count, scaleF, {-1, 1, 1, 1}); // totalPos / validCount + meanSq = genVarName(prefix + "_meansq"); + emit2("mul", meanAll, scaleF, meanSq, {-1, 1, 1, 1}); + denomDims = {-1, 1, 1, 1}; + } else { + meanSq = genVarName(prefix + "_meansq"); + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + denomDims = {-1, 1, H, W}; + } + + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string inv = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, inv, denomDims); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string gammaName = prefix + "_gamma"; + std::string betaName = prefix + "_beta"; + addConstOp(block, gammaName, desc.gamma, {1, static_cast(C), 1, 1}); + addConstOp(block, betaName, desc.beta, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, gammaName, scaled, {-1, C, H, W}); + std::string biased = genVarName(prefix + "_biased"); + emit2("add", scaled, betaName, biased, {-1, C, H, W}); + + std::string activated; + if (act.activation_type == ActivationType::Silu) { + activated = genVarName(prefix + "_act"); + addSiluOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::Mish) { + activated = genVarName(prefix + "_act"); + addMishOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::ReLU) { + activated = genVarName(prefix + "_act"); + auto* op = block->add_operations(); + op->set_type("relu"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(biased); + setShape(op, activated, {-1, C, H, W}); + } else { + activated = biased; + } + std::string out = genVarName(prefix + "_out"); + emit2("mul", activated, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.q_proj.in_channels; + const int H = m_board_y_size, W = m_board_x_size; + const int seq = H * W; + const int numHeads = desc.num_heads, numKVHeads = desc.num_kv_heads; + const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; + const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; + + if (numKVHeads != numHeads) { + throw std::runtime_error(desc.name + ": GQA (numKVHeads != numHeads) not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims, bool transX, bool transY) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, transX); + addBoolScalarConstOp(block, tyName, transY); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { + std::string wName = nm + "_w"; + addConstOp(block, wName, w.weights, w.getWeightShape()); + std::string out = genVarName(nm); + matmul(x2d, wName, out, {-1, total}, false, false); + return out; + }; + std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); + std::string k2d = proj(desc.k_proj, prefix + "_k", kTotal); + std::string v2d = proj(desc.v_proj, prefix + "_v", vTotal); + auto toHeads = [&](const std::string& in2d, const std::string& nm, int nh, int hd) { + std::string r = genVarName(nm + "_r"); + reshape(in2d, r, {-1, seq, nh, hd}, {-1, seq, nh, hd}); + std::string t = genVarName(nm + "_t"); + transpose(r, t, {0, 2, 1, 3}, {-1, nh, seq, hd}); + return t; + }; + std::string qh = toHeads(q2d, prefix + "_qh", numHeads, qHeadDim); + std::string kh = toHeads(k2d, prefix + "_kh", numKVHeads, qHeadDim); + std::string vh = toHeads(v2d, prefix + "_vh", numKVHeads, vHeadDim); + + if (desc.use_rope) { + const int numPairs = qHeadDim / 2; + const int numPairsPerDim = numPairs / 2; + const int dimHalf = qHeadDim / 2; + auto applyRope = [&](const std::string& x, int nh, const std::string& tag) { + std::vector cosFull(static_cast(nh) * seq * qHeadDim, 0.0f); + std::vector sinFull(static_cast(nh) * seq * qHeadDim, 0.0f); + for (int h = 0; h < nh; h++) { + int kvh = (h * numKVHeads) / nh; + for (int xy = 0; xy < seq; xy++) { + int y = xy / W; + int x = xy % W; + for (int p = 0; p < numPairs; p++) { + float angle = 0.0f; + if (desc.learnable_rope) { + float fx = desc.rope_freqs[(kvh * numPairs + p) * 2 + 0]; + float fy = desc.rope_freqs[(kvh * numPairs + p) * 2 + 1]; + angle = static_cast(x) * fx + static_cast(y) * fy; + } else { + if (p < numPairsPerDim) { + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * p) / dimHalf); + angle = static_cast(y) * freq; + } else { + int pAdj = p - numPairsPerDim; + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * pAdj) / dimHalf); + angle = static_cast(x) * freq; + } + } + float c = std::cos(angle), s = std::sin(angle); + size_t base = (static_cast(h) * seq + xy) * qHeadDim + 2 * p; + cosFull[base] = c; cosFull[base + 1] = c; + sinFull[base] = s; sinFull[base + 1] = s; + } + } + } + std::vector R(static_cast(qHeadDim) * qHeadDim, 0.0f); + for (int p = 0; p < numPairs; p++) { + R[(2 * p) * qHeadDim + (2 * p + 1)] = 1.0f; + R[(2 * p + 1) * qHeadDim + (2 * p)] = -1.0f; + } + std::string cosName = prefix + "_" + tag + "_cos"; + std::string sinName = prefix + "_" + tag + "_sin"; + std::string rName = prefix + "_" + tag + "_R"; + addConstOp(block, cosName, cosFull, {1, nh, seq, qHeadDim}); + addConstOp(block, sinName, sinFull, {1, nh, seq, qHeadDim}); + // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. + addConstOp(block, rName, R, {1, 1, qHeadDim, qHeadDim}); + std::string rotated = genVarName(prefix + "_" + tag + "_rot"); + matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); + std::string xc = genVarName(prefix + "_" + tag + "_xc"); + binary("mul", x, cosName, xc, {-1, nh, seq, qHeadDim}); + std::string rs = genVarName(prefix + "_" + tag + "_rs"); + binary("mul", rotated, sinName, rs, {-1, nh, seq, qHeadDim}); + std::string out = genVarName(prefix + "_" + tag + "_rope"); + binary("add", xc, rs, out, {-1, nh, seq, qHeadDim}); + return out; + }; + qh = applyRope(qh, numHeads, "q"); + kh = applyRope(kh, numKVHeads, "k"); + } + + std::string scores = genVarName(prefix + "_scores"); + matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); + std::string scaleName = prefix + "_scale"; + addFloatScalarConstOp(block, scaleName, 1.0f / std::sqrt(static_cast(qHeadDim))); + std::string scaled = genVarName(prefix + "_sc"); + binary("mul", scores, scaleName, scaled, {-1, numHeads, seq, seq}); + + // mask [B,1,H,W] -> [B,1,1,seq] directly (contiguous reshape; H,W already trailing so the + // row-major flatten gives seq index xy=y*W+x). No transpose -> avoids the reshape-after- + // transpose issue, and is also correct for non-full boards. + std::string maskSeq = genVarName(prefix + "_mseq"); + reshape(mask, maskSeq, {-1, 1, 1, seq}, {-1, 1, 1, seq}); + std::string oneName = prefix + "_one"; + addFloatScalarConstOp(block, oneName, 1.0f); + std::string mm1 = genVarName(prefix + "_mm1"); + binary("sub", maskSeq, oneName, mm1, {-1, 1, 1, seq}); + // Use an FP16-safe magnitude: 1e9 overflows FP16 to +inf, and for valid keys + // (maskSeq-1 == 0) the product 0 * inf becomes NaN, poisoning the whole softmax. + // 1e4 is well within FP16 range and exp(score - 1e4) still underflows to 0. + std::string bigName = prefix + "_big"; + addFloatScalarConstOp(block, bigName, 1.0e4f); + std::string keyBias = genVarName(prefix + "_kb"); + binary("mul", mm1, bigName, keyBias, {-1, 1, 1, seq}); + std::string scoresMasked = genVarName(prefix + "_scm"); + binary("add", scaled, keyBias, scoresMasked, {-1, numHeads, seq, seq}); + + std::string attn = genVarName(prefix + "_attn"); + { + std::string axisName = attn + "_axis"; + addIntScalarConstOp(block, axisName, 3); + auto* op = block->add_operations(); + op->set_type("softmax"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(scoresMasked); + (*op->mutable_inputs())["axis"].add_arguments()->set_name(axisName); + setShape(op, attn, {-1, numHeads, seq, seq}); + } + + std::string attnOut = genVarName(prefix + "_ao"); + matmul(attn, vh, attnOut, {-1, numHeads, seq, vHeadDim}, false, false); + + // Output projection, done per-head to avoid reshape-after-transpose: CoreML's reshape + // ignores an immediately-preceding transpose, so merging [head,dim]->channels after a + // transpose scrambles the data. Instead slice each head from attnOut (head is the + // contiguous axis 1), reshape (leading-merge only), matmul its weight slice, and sum. + // out[b,s,c] = sum_h sum_d attnOut[b,h,s,d] * outProj.weights[(h*vHeadDim+d)*outC + c] + const int outC = desc.out_proj.out_channels; + std::string proj2d; + for (int h = 0; h < numHeads; h++) { + std::string aoh = genVarName(prefix + "_aoh"); + { + std::string beginName = aoh + "_begin", sizeName = aoh + "_size"; + addIntArrayConstOp(block, beginName, {0, h, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, vHeadDim}); + auto* op = block->add_operations(); + op->set_type("slice_by_size"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(attnOut); + (*op->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*op->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(op, aoh, {-1, 1, seq, vHeadDim}); + } + std::string aoh2d = genVarName(prefix + "_aoh2d"); + reshape(aoh, aoh2d, {-1, vHeadDim}, {-1, vHeadDim}); // [B*seq, vHeadDim] + std::string wh = prefix + "_ow" + std::to_string(h); + std::vector whData(static_cast(vHeadDim) * outC); + for (int d = 0; d < vHeadDim; d++) + for (int c = 0; c < outC; c++) + whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; + addConstOp(block, wh, whData, {vHeadDim, outC}); + std::string contrib = genVarName(prefix + "_contrib"); + matmul(aoh2d, wh, contrib, {-1, outC}, false, false); + if (h == 0) { + proj2d = contrib; + } else { + std::string acc = genVarName(prefix + "_acc"); + binary("add", proj2d, contrib, acc, {-1, outC}); + proj2d = acc; + } + } + std::string projNHWC = genVarName(prefix + "_pnhwc"); + reshape(proj2d, projNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string projNCHW = genVarName(prefix + "_pnchw"); + transpose(projNHWC, projNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", projNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int ffn = desc.ffn_channels; + const int H = m_board_y_size, W = m_board_x_size; + + if (!desc.use_swiglu) { + throw std::runtime_error(desc.name + ": non-SwiGLU transformer FFN not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, false); + addBoolScalarConstOp(block, tyName, false); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + + std::string w1 = prefix + "_w1"; + addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); + std::string a = genVarName(prefix + "_a"); + matmul(x2d, w1, a, {-1, ffn}); + std::string wg = prefix + "_wg"; + addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string g = genVarName(prefix + "_g"); + matmul(x2d, wg, g, {-1, ffn}); + + std::string sig = genVarName(prefix + "_sig"); + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(a); + setShape(op, sig, {-1, ffn}); + } + std::string siluA = genVarName(prefix + "_silu"); + binary("mul", a, sig, siluA, {-1, ffn}); + std::string h = genVarName(prefix + "_h"); + binary("mul", siluA, g, h, {-1, ffn}); + + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + std::string o = genVarName(prefix + "_o"); + matmul(h, w2, o, {-1, C}); + + std::string oNHWC = genVarName(prefix + "_onhwc"); + reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string oNCHW = genVarName(prefix + "_onchw"); + transpose(oNHWC, oNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", oNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, const std::string& spatial_input, const std::string& global_input, @@ -1747,12 +2321,23 @@ std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, } else if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { const auto& block_desc = std::get(*entry.block); x = buildNestedBottleneckBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, block_desc, mask, prefix); } } // Trunk tip - std::string trunk_out = genVarName("trunk_tip"); - addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + std::string trunk_out; + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk_out = genVarName("trunk_tip"); + addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + } else { + trunk_out = addTrunkRMSNorm(block, x, trunk.trunk_tip_rms_norm, trunk.trunk_tip_activation, mask, "trunk_tip_rms"); + } return trunk_out; } @@ -1898,6 +2483,12 @@ std::string MILBuilder::buildNestedBottleneckBlock(CoreML::Specification::MILSpe } else if (entry.block_kind == GLOBAL_POOLING_BLOCK_KIND) { const auto& nested = std::get(*entry.block); x = buildGlobalPoolingResidualBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, nested, mask, nested_prefix); } } diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..5d25b963a 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -120,6 +120,44 @@ class MILBuilder { int rank, int channels); + void addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels); + + // Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions. + void setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims); + + // Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out. + std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix); + + // Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out. + std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block, const std::string& input, const std::string& mask, diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 2d06c27e5..5dcb80f5d 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -315,6 +315,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) { layer.activation_type = ActivationType::ReLU; } else if (activation_str == "ACTIVATION_MISH") { layer.activation_type = ActivationType::Mish; + } else if (activation_str == "ACTIVATION_SILU") { + layer.activation_type = ActivationType::Silu; } else { throw std::runtime_error("Unknown activation type: " + activation_str); } @@ -420,6 +422,98 @@ static void checkBlockChannels(const std::string& block_name, const std::string& } } +TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() { + TransformerRMSNormDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": transformer rmsnorm numChannels must be >= 1"); + } + layer.weight = readFloats(layer.num_channels, layer.name + "/weight"); + return layer; +} + +RMSNormLayerDesc KataGoParser::parseRMSNormLayer() { + RMSNormLayerDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + layer.spatial = (readInt() != 0); + layer.cgroup_size = readInt(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": rmsnorm numChannels must be >= 1"); + } + if (layer.cgroup_size != 0) { + throw std::runtime_error(layer.name + ": grouped spatial RMSNorm is not supported"); + } + layer.gamma = readFloats(layer.num_channels, layer.name + "/gamma"); + layer.beta = readFloats(layer.num_channels, layer.name + "/beta"); + return layer; +} + +TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version) { + TransformerAttentionBlockDesc block; + block.name = readString(); + block.num_heads = readInt(); + block.num_kv_heads = readInt(); + block.q_head_dim = readInt(); + block.v_head_dim = readInt(); + block.use_rope = (readInt() != 0); + block.learnable_rope = (readInt() != 0); + + if (block.num_heads < 1 || block.num_kv_heads < 1 || (block.num_heads % block.num_kv_heads != 0)) { + throw std::runtime_error(block.name + ": invalid numHeads/numKVHeads"); + } + if (block.use_rope && (block.q_head_dim % 2 != 0)) { + throw std::runtime_error(block.name + ": qHeadDim must be even when RoPE is used"); + } + + block.pre_ln = parseTransformerRMSNorm(); + block.q_proj = parseMatMulLayer(); + block.k_proj = parseMatMulLayer(); + block.v_proj = parseMatMulLayer(); + block.out_proj = parseMatMulLayer(); + + if (block.use_rope) { + if (block.learnable_rope) { + readString(); // ropeFreqs name + block.rope_num_kv_heads = readInt(); + block.rope_num_pairs = readInt(); + int rope_dim2 = readInt(); + if (block.rope_num_kv_heads != block.num_kv_heads || + block.rope_num_pairs != block.q_head_dim / 2 || rope_dim2 != 2) { + throw std::runtime_error(block.name + ": invalid learnable rope header"); + } + block.rope_freqs = readFloats( + static_cast(block.rope_num_kv_heads) * block.rope_num_pairs * 2, + block.name + "/rope_freqs"); + } else { + readString(); // ropeTheta name + block.rope_theta = readFloat(); + } + } + return block; +} + +TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version) { + TransformerFFNBlockDesc block; + block.name = readString(); + block.num_channels = readInt(); + block.ffn_channels = readInt(); + block.use_swiglu = (readInt() != 0); + if (block.num_channels < 1 || block.ffn_channels < 1) { + throw std::runtime_error(block.name + ": transformer ffn channels must be positive"); + } + block.pre_ln = parseTransformerRMSNorm(); + block.linear1 = parseMatMulLayer(); + if (block.use_swiglu) { + block.linear_gate = parseMatMulLayer(); + } + block.linear2 = parseMatMulLayer(); + return block; +} + std::vector KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) { std::vector blocks; blocks.reserve(num_blocks); @@ -449,6 +543,14 @@ std::vector KataGoParser::parseBlockStack(int model_version, int num desc.pre_bn.num_channels, desc.post_conv.out_channels, trunk_num_channels); entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_attention_block") { + entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND; + auto desc = parseTransformerAttentionBlock(model_version); + entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_ffn_block") { + entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND; + auto desc = parseTransformerFFNBlock(model_version); + entry.block = std::make_shared(std::move(desc)); } else { throw std::runtime_error("Unknown block kind: " + block_kind_name); } @@ -505,11 +607,17 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) std::to_string(trunk.gpool_num_channels) + ")"); } - // Version >= 15 has 6 unused int parameters + // Version >= 15: first int is trunkNormKind, followed by 5 unused ints. if (model_version >= 15) { - for (int i = 0; i < 6; i++) { + trunk.trunk_norm_kind = readInt(); + for (int i = 0; i < 5; i++) { readInt(); } + if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD && + trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) { + throw std::runtime_error(trunk.name + ": unknown trunk norm kind: " + + std::to_string(trunk.trunk_norm_kind)); + } } trunk.initial_conv = parseConvLayer(); @@ -548,14 +656,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) // Parse residual blocks trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels); - trunk.trunk_tip_bn = parseBatchNormLayer(); - trunk.trunk_tip_activation = parseActivationLayer(model_version); - if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { - throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + - std::to_string(trunk.trunk_tip_bn.num_channels) + - ") != trunkNumChannels (" + - std::to_string(trunk.trunk_num_channels) + ")"); + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk.trunk_tip_bn = parseBatchNormLayer(); + if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + + std::to_string(trunk.trunk_tip_bn.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } + } else { + trunk.trunk_tip_rms_norm = parseRMSNormLayer(); + if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" + + std::to_string(trunk.trunk_tip_rms_norm.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } } + trunk.trunk_tip_activation = parseActivationLayer(model_version); return trunk; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index cbcfdefa8..9c9935b57 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -50,11 +50,15 @@ class KataGoParser { ActivationLayerDesc parseActivationLayer(int model_version); MatMulLayerDesc parseMatMulLayer(); MatBiasLayerDesc parseMatBiasLayer(); + TransformerRMSNormDesc parseTransformerRMSNorm(); + RMSNormLayerDesc parseRMSNormLayer(); // Block parsing functions ResidualBlockDesc parseResidualBlock(int model_version); GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version); NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels); + TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version); + TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version); std::vector parseBlockStack(int model_version, int num_blocks, int trunk_num_channels); // Component parsing functions diff --git a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp index 284b26cd3..1e0ae5f12 100644 --- a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp +++ b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp @@ -19,10 +19,15 @@ namespace katagocoreml { enum class ActivationType : int { Identity = 0, ReLU = 1, - Mish = 2 + Mish = 2, + Silu = 3 // MISH_SCALE8 = 12 is internal optimization, treated as Mish }; +/// Trunk normalization kind (matching KataGo's desc.h) +constexpr int TRUNK_NORM_KIND_STANDARD = 0; +constexpr int TRUNK_NORM_KIND_RMSNORM = 1; + // ============================================================================ // Block Kind Constants // ============================================================================ @@ -31,6 +36,8 @@ enum class ActivationType : int { constexpr int ORDINARY_BLOCK_KIND = 0; constexpr int GLOBAL_POOLING_BLOCK_KIND = 2; constexpr int NESTED_BOTTLENECK_BLOCK_KIND = 3; +constexpr int TRANSFORMER_ATTENTION_BLOCK_KIND = 4; +constexpr int TRANSFORMER_FFN_BLOCK_KIND = 5; // ============================================================================ // Layer Descriptors @@ -98,6 +105,25 @@ struct MatBiasLayerDesc { std::vector weights; // Shape: [num_channels] }; +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +struct TransformerRMSNormDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + std::vector weight; // Shape: [num_channels] +}; + +/// Full-featured RMSNorm (gamma/beta, spatial mode) used at the trunk tip. +struct RMSNormLayerDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + bool spatial = false; + int cgroup_size = 0; + std::vector gamma; // Shape: [num_channels] + std::vector beta; // Shape: [num_channels] +}; + // ============================================================================ // Block Descriptors // ============================================================================ @@ -106,12 +132,16 @@ struct MatBiasLayerDesc { struct ResidualBlockDesc; struct GlobalPoolingResidualBlockDesc; struct NestedBottleneckResidualBlockDesc; +struct TransformerAttentionBlockDesc; +struct TransformerFFNBlockDesc; /// Block descriptor variant using BlockDesc = std::variant< ResidualBlockDesc, GlobalPoolingResidualBlockDesc, - NestedBottleneckResidualBlockDesc + NestedBottleneckResidualBlockDesc, + TransformerAttentionBlockDesc, + TransformerFFNBlockDesc >; /// Block with its kind @@ -165,6 +195,38 @@ struct NestedBottleneckResidualBlockDesc { ConvLayerDesc post_conv; }; +/// Transformer self-attention block descriptor (pre-norm, multi-head, optional 2D RoPE, GQA). +struct TransformerAttentionBlockDesc { + std::string name; + int num_heads = 0; + int num_kv_heads = 0; + int q_head_dim = 0; + int v_head_dim = 0; + bool use_rope = false; + bool learnable_rope = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc q_proj; + MatMulLayerDesc k_proj; + MatMulLayerDesc v_proj; + MatMulLayerDesc out_proj; + int rope_num_kv_heads = 0; + int rope_num_pairs = 0; + std::vector rope_freqs; // learnable: (num_kv_heads, num_pairs, 2) flattened + float rope_theta = 0.0f; +}; + +/// Transformer feed-forward (SwiGLU) block descriptor. +struct TransformerFFNBlockDesc { + std::string name; + int num_channels = 0; + int ffn_channels = 0; + bool use_swiglu = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc linear1; + MatMulLayerDesc linear_gate; // only used when use_swiglu + MatMulLayerDesc linear2; +}; + // ============================================================================ // SGF Metadata Encoder (v15+) // ============================================================================ @@ -202,7 +264,9 @@ struct TrunkDesc { MatMulLayerDesc initial_matmul; std::optional sgf_metadata_encoder; std::vector blocks; + int trunk_norm_kind = TRUNK_NORM_KIND_STANDARD; BatchNormLayerDesc trunk_tip_bn; + RMSNormLayerDesc trunk_tip_rms_norm; ActivationLayerDesc trunk_tip_activation; }; diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 77a2d45c9..6a1cecabc 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -130,6 +130,8 @@ ActivationKind activationLayerDescToSwift(const ActivationLayerDesc* desc) { return ActivationKind::mish(); case ACTIVATION_MISH_SCALE8: return ActivationKind::identity(); // Metal/CoreML does not use scaled mish + case ACTIVATION_SILU: + return ActivationKind::silu(); case ACTIVATION_IDENTITY: return ActivationKind::identity(); default: @@ -217,6 +219,58 @@ SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(con postConv); } +/// Convert a transformer RMSNorm description from C++ to Swift +SWTransformerRMSNormDesc transformerRMSNormDescToSwift(const TransformerRMSNormDesc* desc) { + return createSWTransformerRMSNormDesc( + desc->numChannels, + desc->epsilon, + (float*)desc->weight.data()); +} + +/// Convert a transformer attention block description from C++ to Swift +SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const TransformerAttentionDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc qProj = matMulLayerDescToSwift(&desc->qProj); + SWMatMulLayerDesc kProj = matMulLayerDescToSwift(&desc->kProj); + SWMatMulLayerDesc vProj = matMulLayerDescToSwift(&desc->vProj); + SWMatMulLayerDesc outProj = matMulLayerDescToSwift(&desc->outProj); + float* ropeFreqs = desc->ropeFreqs.empty() ? nullptr : (float*)desc->ropeFreqs.data(); + + return createSWTransformerAttentionBlockDesc( + desc->numHeads, + desc->numKVHeads, + desc->qHeadDim, + desc->vHeadDim, + desc->useRope, + desc->learnableRope, + preLN, + qProj, + kProj, + vProj, + outProj, + desc->ropeNumKVHeads, + desc->ropeNumPairs, + ropeFreqs, + desc->ropeTheta); +} + +/// Convert a transformer FFN block description from C++ to Swift +SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); + SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); + SWMatMulLayerDesc linear2 = matMulLayerDescToSwift(&desc->linear2); + + return createSWTransformerFFNBlockDesc( + desc->numChannels, + desc->ffnChannels, + desc->useSwiGLU, + preLN, + linear1, + linearGate, + linear2); +} + /// Convert residual blocks from C++ to Swift swift::Array residualBlocksToSwift(const vector>& blocks) { auto builder = createBlockDescriptorBuilder(); @@ -230,9 +284,12 @@ swift::Array residualBlocksToSwift(const vector sGFMetadataEncoderDescToSwift(const SG } /// Convert a trunk description from C++ to Swift +SWRMSNormLayerDesc rmsNormLayerDescToSwift(const RMSNormLayerDesc* desc) { + float* gamma = desc->gamma.empty() ? nullptr : (float*)desc->gamma.data(); + float* beta = desc->beta.empty() ? nullptr : (float*)desc->beta.data(); + return createSWRMSNormLayerDesc( + desc->numChannels, + desc->epsilon, + desc->spatial, + gamma, + beta); +} + SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { SWConvLayerDesc initialConv = convLayerDescToSwift(&trunk->initialConv); SWMatMulLayerDesc initialMatMul = matMulLayerDescToSwift(&trunk->initialMatMul); auto sgfMetadataEncoder = sGFMetadataEncoderDescToSwift(&trunk->sgfMetadataEncoder); auto swBlocks = residualBlocksToSwift(trunk->blocks); - if(trunk->trunkNormKind != TRUNK_NORM_KIND_STANDARD) - throw StringError("Trunk RMSNorm is not yet supported by the Metal backend"); SWBatchNormLayerDesc trunkTipBN = batchNormLayerDescToSwift(&trunk->trunkTipBN); + SWRMSNormLayerDesc trunkTipRMSNorm = rmsNormLayerDescToSwift(&trunk->trunkTipRMSNorm); ActivationKind trunkTipActivation = activationLayerDescToSwift(&trunk->trunkTipActivation); return createSWTrunkDesc( @@ -285,7 +352,9 @@ SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { initialMatMul, sgfMetadataEncoder, swBlocks, + trunk->trunkNormKind, trunkTipBN, + trunkTipRMSNorm, trunkTipActivation); } diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index bbd2255bc..f275c7925 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -76,6 +76,11 @@ extension MPSGraph { return mulTensor } + + /// SiLU / Swish activation: x * sigmoid(x). Numerically stable across FP16/FP32. + func silu(tensor: MPSGraphTensor) -> MPSGraphTensor { + return multiplication(tensor, sigmoid(with: tensor, name: nil), name: nil) + } } // MARK: - Input Shape Utilities @@ -358,6 +363,7 @@ public enum ActivationKind { case identity case relu case mish + case silu } /// A struct that represents a description of convolutional layer. @@ -487,6 +493,63 @@ public func createSWMatBiasLayerDesc( weights: weights) } +/// A lightweight RMSNorm description used inside transformer blocks (weight only, no bias). +public struct SWTransformerRMSNormDesc { + let numChannels: NSNumber + let epsilon: Float + let weight: UnsafeMutablePointer + + init(numChannels: NSNumber, epsilon: Float, weight: UnsafeMutablePointer) { + self.numChannels = numChannels + self.epsilon = epsilon + self.weight = weight + } +} + +public func createSWTransformerRMSNormDesc( + numChannels: Int32, + epsilon: Float, + weight: UnsafeMutablePointer +) -> SWTransformerRMSNormDesc { + return SWTransformerRMSNormDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + weight: weight) +} + +/// A full-featured RMSNorm description (gamma/beta, spatial mode), used at the trunk tip. +public struct SWRMSNormLayerDesc { + let numChannels: NSNumber + let epsilon: Float + let spatial: Bool + let gamma: UnsafeMutablePointer? + let beta: UnsafeMutablePointer? + + init(numChannels: NSNumber, epsilon: Float, spatial: Bool, + gamma: UnsafeMutablePointer?, beta: UnsafeMutablePointer?) { + self.numChannels = numChannels + self.epsilon = epsilon + self.spatial = spatial + self.gamma = gamma + self.beta = beta + } +} + +public func createSWRMSNormLayerDesc( + numChannels: Int32, + epsilon: Float, + spatial: Bool, + gamma: UnsafeMutablePointer?, + beta: UnsafeMutablePointer? +) -> SWRMSNormLayerDesc { + return SWRMSNormLayerDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + spatial: spatial, + gamma: gamma, + beta: beta) +} + // MARK: - Core Layers /// A class that represents a convolutional layer using MPSGraph @@ -612,6 +675,8 @@ struct ActivationLayer { resultTensor = graph.reLU(with: sourceTensor, name: nil) case .mish: resultTensor = graph.mish(tensor: sourceTensor) + case .silu: + resultTensor = graph.silu(tensor: sourceTensor) default: resultTensor = sourceTensor } @@ -987,6 +1052,140 @@ public func createSWNestedBottleneckResidualBlockDesc( postConv: postConv) } +public class SWTransformerAttentionBlockDesc: BlockDescriptor { + let numHeads: Int + let numKVHeads: Int + let qHeadDim: Int + let vHeadDim: Int + let useRope: Bool + let learnableRope: Bool + let preLN: SWTransformerRMSNormDesc + let qProj: SWMatMulLayerDesc + let kProj: SWMatMulLayerDesc + let vProj: SWMatMulLayerDesc + let outProj: SWMatMulLayerDesc + let ropeNumKVHeads: Int + let ropeNumPairs: Int + let ropeFreqs: UnsafeMutablePointer? // learnable: (numKVHeads, numPairs, 2) flattened + let ropeTheta: Float + + init( + numHeads: Int, + numKVHeads: Int, + qHeadDim: Int, + vHeadDim: Int, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int, + ropeNumPairs: Int, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float + ) { + self.numHeads = numHeads + self.numKVHeads = numKVHeads + self.qHeadDim = qHeadDim + self.vHeadDim = vHeadDim + self.useRope = useRope + self.learnableRope = learnableRope + self.preLN = preLN + self.qProj = qProj + self.kProj = kProj + self.vProj = vProj + self.outProj = outProj + self.ropeNumKVHeads = ropeNumKVHeads + self.ropeNumPairs = ropeNumPairs + self.ropeFreqs = ropeFreqs + self.ropeTheta = ropeTheta + } +} + +public func createSWTransformerAttentionBlockDesc( + numHeads: Int32, + numKVHeads: Int32, + qHeadDim: Int32, + vHeadDim: Int32, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int32, + ropeNumPairs: Int32, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float +) -> SWTransformerAttentionBlockDesc { + return SWTransformerAttentionBlockDesc( + numHeads: Int(numHeads), + numKVHeads: Int(numKVHeads), + qHeadDim: Int(qHeadDim), + vHeadDim: Int(vHeadDim), + useRope: useRope, + learnableRope: learnableRope, + preLN: preLN, + qProj: qProj, + kProj: kProj, + vProj: vProj, + outProj: outProj, + ropeNumKVHeads: Int(ropeNumKVHeads), + ropeNumPairs: Int(ropeNumPairs), + ropeFreqs: ropeFreqs, + ropeTheta: ropeTheta) +} + +public class SWTransformerFFNBlockDesc: BlockDescriptor { + let numChannels: Int + let ffnChannels: Int + let useSwiGLU: Bool + let preLN: SWTransformerRMSNormDesc + let linear1: SWMatMulLayerDesc + let linearGate: SWMatMulLayerDesc + let linear2: SWMatMulLayerDesc + + init( + numChannels: Int, + ffnChannels: Int, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc + ) { + self.numChannels = numChannels + self.ffnChannels = ffnChannels + self.useSwiGLU = useSwiGLU + self.preLN = preLN + self.linear1 = linear1 + self.linearGate = linearGate + self.linear2 = linear2 + } +} + +public func createSWTransformerFFNBlockDesc( + numChannels: Int32, + ffnChannels: Int32, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc +) -> SWTransformerFFNBlockDesc { + return SWTransformerFFNBlockDesc( + numChannels: Int(numChannels), + ffnChannels: Int(ffnChannels), + useSwiGLU: useSwiGLU, + preLN: preLN, + linear1: linear1, + linearGate: linearGate, + linear2: linear2) +} + public class BlockDescriptorBuilder { public var blockDescriptors: [BlockDescriptor] = [] @@ -1001,6 +1200,316 @@ public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { return BlockDescriptorBuilder() } +// MARK: - Transformer Layers + +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +/// Input/output are NCHW [B, C, H, W]. Normalizes across channels per spatial position, +/// scales by per-channel weight, and masks the output. +struct TransformerRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerRMSNormDesc + ) { + let numChannels = descriptor.numChannels + let dataType = sourceTensor.dataType + + // meanSq over channel axis (1): [B,1,H,W] + let sq = graph.square(with: sourceTensor, name: nil) + let sumSq = graph.reductionSum(with: sq, axis: 1, name: nil) + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + let meanSq = graph.multiplication(sumSq, invC, name: nil) + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + // scale by per-channel weight [1, C, 1, 1] + let weightShape: [NSNumber] = [1, numChannels, 1, 1] + let weightData = Data(floatsNoCopy: descriptor.weight, shape: weightShape) + let weightTensor = graph.constant(weightData, shape: weightShape, dataType: dataType) + let scaled = graph.multiplication(normalized, weightTensor, name: nil) + + resultTensor = graph.multiplication(scaled, maskTensor, name: nil) + } +} + +/// Full-featured RMSNorm for the trunk tip: gamma/beta, spatial or per-position mode, and a +/// fused activation. Input/output are NCHW [B, C, H, W]. Mirrors the Eigen RMSNormLayer. +struct TrunkRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWRMSNormLayerDesc, + activationKind: ActivationKind + ) { + let dataType = sourceTensor.dataType + let numChannels = descriptor.numChannels + + // Zero invalid positions before accumulating sum of squares. + let masked = graph.multiplication(sourceTensor, maskTensor, name: nil) + let sq = graph.square(with: masked, name: nil) + + let meanSq: MPSGraphTensor + if descriptor.spatial { + // Normalize over channels AND valid spatial positions per batch element. + let sumSq = graph.reductionSum(with: sq, axes: [1, 2, 3], name: nil) // [B,1,1,1] + let count = graph.reductionSum(with: maskTensor, axes: [1, 2, 3], name: nil) // valid positions + let cTensor = graph.constant(numChannels.doubleValue, dataType: dataType) + let totalElts = graph.multiplication(count, cTensor, name: nil) + meanSq = graph.division(sumSq, totalElts, name: nil) + } else { + // Per-position normalization across channels. + let sumSq = graph.reductionSum(with: sq, axes: [1], name: nil) // [B,1,H,W] + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + meanSq = graph.multiplication(sumSq, invC, name: nil) + } + + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + let gammaShape: [NSNumber] = [1, numChannels, 1, 1] + let gammaTensor = graph.constant(Data(floatsNoCopy: descriptor.gamma!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let betaTensor = graph.constant(Data(floatsNoCopy: descriptor.beta!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let scaled = graph.addition(graph.multiplication(normalized, gammaTensor, name: nil), betaTensor, name: nil) + + let activated = ActivationLayer(graph: graph, sourceTensor: scaled, activationKind: activationKind).resultTensor + resultTensor = graph.multiplication(activated, maskTensor, name: nil) + } +} + +/// A transformer self-attention block (pre-norm, multi-head, optional 2D RoPE, GQA). +/// Mirrors the Eigen reference: RMSNorm -> Q/K/V projections -> RoPE -> scaled dot-product +/// attention with masked softmax -> output projection -> masked residual. +/// Tensors are NCHW [B, C, H, W]; spatial positions (H*W, ordered y*W+x) are the sequence. +struct TransformerAttentionBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerAttentionBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let dataType = sourceTensor.dataType + let numHeads = descriptor.numHeads + let numKVHeads = descriptor.numKVHeads + let qHeadDim = descriptor.qHeadDim + let vHeadDim = descriptor.vHeadDim + let nnX = nnXLen.intValue + let nnY = nnYLen.intValue + let seq = nnX * nnY + + // 1. RMSNorm (NCHW) + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + + // To NHWC [B,H,W,C] so that reshape [-1, C] groups channels per position. + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. Q/K/V projections via matmul over channels -> [B*seq, heads*dim] + let q = MatMulLayer(graph: graph, descriptor: descriptor.qProj, sourceTensor: normedNHWC).resultTensor + let k = MatMulLayer(graph: graph, descriptor: descriptor.kProj, sourceTensor: normedNHWC).resultTensor + let v = MatMulLayer(graph: graph, descriptor: descriptor.vProj, sourceTensor: normedNHWC).resultTensor + + // 3. reshape to [B, heads, seq, dim] + var qh = TransformerAttentionBlock.toHeads(graph, q, seq: seq, numHeads: numHeads, headDim: qHeadDim) + var kh = TransformerAttentionBlock.toHeads(graph, k, seq: seq, numHeads: numKVHeads, headDim: qHeadDim) + let vh = TransformerAttentionBlock.toHeads(graph, v, seq: seq, numHeads: numKVHeads, headDim: vHeadDim) + + // 4. RoPE on Q and K + if descriptor.useRope { + let numPairs = qHeadDim / 2 + // Q heads map to KV heads via kvh = h * numKVHeads / numHeads (matches Eigen). + let (qCos, qSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in (h * numKVHeads) / numHeads }) + let (kCos, kSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numKVHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in h }) + qh = TransformerAttentionBlock.applyRope(graph, qh, cosT: qCos, sinT: qSin, + numHeads: numHeads, seq: seq, numPairs: numPairs) + kh = TransformerAttentionBlock.applyRope(graph, kh, cosT: kCos, sinT: kSin, + numHeads: numKVHeads, seq: seq, numPairs: numPairs) + } + + // GQA: if numKVHeads < numHeads, repeat KV heads so they align with query heads. + var khExp = kh + var vhExp = vh + if numKVHeads != numHeads { + let groupSize = numHeads / numKVHeads + khExp = TransformerAttentionBlock.repeatKVHeads(graph, kh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: qHeadDim) + vhExp = TransformerAttentionBlock.repeatKVHeads(graph, vh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: vHeadDim) + } + + // 5. scores = scale * Q @ K^T -> [B, heads, seq, seq] + let khT = graph.transpose(khExp, permutation: [0, 1, 3, 2], name: nil) + var scores = graph.matrixMultiplication(primary: qh, secondary: khT, name: nil) + let scale = graph.constant(1.0 / Double(qHeadDim).squareRoot(), dataType: dataType) + scores = graph.multiplication(scores, scale, name: nil) + + // Mask keys: add (maskKey - 1) * BIG so masked key columns get ~ -inf before softmax. + // maskTensor [B,1,H,W] -> [B,1,1,seq] + let maskNHWC = graph.transpose(maskTensor, permutation: [0, 2, 3, 1], name: nil) // [B,H,W,1] + let maskSeq = graph.reshape(maskNHWC, shape: [-1, 1, 1, seq as NSNumber], name: nil) + let one = graph.constant(1.0, dataType: dataType) + let big = graph.constant(1.0e9, dataType: dataType) + let keyBias = graph.multiplication(graph.subtraction(maskSeq, one, name: nil), big, name: nil) + scores = graph.addition(scores, keyBias, name: nil) + + // 6. softmax over key axis (last) + let attn = graph.softMax(with: scores, axis: 3, name: nil) + + // 7. out = attn @ V -> [B, heads, seq, vHeadDim] + let attnOut = graph.matrixMultiplication(primary: attn, secondary: vhExp, name: nil) + + // 8. back to [B*seq, heads*vHeadDim] + let outHeadsLast = graph.transpose(attnOut, permutation: [0, 2, 1, 3], name: nil) // [B,seq,heads,vHeadDim] + let outFlat = graph.reshape(outHeadsLast, shape: [-1, (numHeads * vHeadDim) as NSNumber], name: nil) + + // 9. output projection -> [B*seq, C] + let proj = MatMulLayer(graph: graph, descriptor: descriptor.outProj, sourceTensor: outFlat).resultTensor + + // 10. reshape to NHWC then NCHW + let outChannels = descriptor.outProj.outChannels + let projNHWC = graph.reshape(proj, shape: [-1, nnYLen, nnXLen, outChannels], name: nil) + let projNCHW = graph.transpose(projNHWC, permutation: [0, 3, 1, 2], name: nil) + + // 11. masked residual + let masked = graph.multiplication(projNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } + + /// Reshape [B*seq, numHeads*headDim] -> [B, numHeads, seq, headDim]. + static func toHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, seq: Int, numHeads: Int, headDim: Int) -> MPSGraphTensor { + let reshaped = graph.reshape(x, shape: [-1, seq as NSNumber, numHeads as NSNumber, headDim as NSNumber], name: nil) + return graph.transpose(reshaped, permutation: [0, 2, 1, 3], name: nil) + } + + /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. + static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { + // Insert a group axis then broadcast: [B,kv,1,seq,dim] -> [B,kv,group,seq,dim] -> [B,kv*group,seq,dim] + let expanded = graph.reshape(x, shape: [-1, numKVHeads as NSNumber, 1, seq as NSNumber, headDim as NSNumber], name: nil) + let targetShape: [NSNumber] = [-1, numKVHeads as NSNumber, groupSize as NSNumber, seq as NSNumber, headDim as NSNumber] + let broadcast = graph.broadcast(expanded, shape: targetShape, name: nil) + return graph.reshape(broadcast, shape: [-1, (numKVHeads * groupSize) as NSNumber, seq as NSNumber, headDim as NSNumber], name: nil) + } + + /// Apply interleaved-pair RoPE to [B, nHeads, seq, headDim] using cos/sin tables [1,nHeads,seq,numPairs]. + static func applyRope(_ graph: MPSGraph, _ x: MPSGraphTensor, cosT: MPSGraphTensor, sinT: MPSGraphTensor, + numHeads: Int, seq: Int, numPairs: Int) -> MPSGraphTensor { + let pairsShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 2] + let xPairs = graph.reshape(x, shape: pairsShape, name: nil) + let evenShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber] + let xEven = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 0, length: 1, name: nil), shape: evenShape, name: nil) + let xOdd = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 1, length: 1, name: nil), shape: evenShape, name: nil) + let outEven = graph.subtraction(graph.multiplication(xEven, cosT, name: nil), graph.multiplication(xOdd, sinT, name: nil), name: nil) + let outOdd = graph.addition(graph.multiplication(xEven, sinT, name: nil), graph.multiplication(xOdd, cosT, name: nil), name: nil) + let pairShape5: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 1] + let outEvenE = graph.reshape(outEven, shape: pairShape5, name: nil) + let outOddE = graph.reshape(outOdd, shape: pairShape5, name: nil) + let stacked = graph.concatTensors([outEvenE, outOddE], dimension: 4, name: nil) + return graph.reshape(stacked, shape: [-1, numHeads as NSNumber, seq as NSNumber, (numPairs * 2) as NSNumber], name: nil) + } + + /// Build RoPE cos/sin constant tensors of shape [1, nHeads, seq, numPairs]. + static func makeRopeTables(_ graph: MPSGraph, descriptor: SWTransformerAttentionBlockDesc, + nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, + dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { + let count = nHeads * seq * numPairs + let cosBuf = UnsafeMutablePointer.allocate(capacity: count) + let sinBuf = UnsafeMutablePointer.allocate(capacity: count) + let numPairsPerDim = numPairs / 2 + let dimHalf = qHeadDim / 2 + for h in 0.. SiLU(linear1)*gate -> linear2 -> masked residual. +struct TransformerFFNBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerFFNBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let numChannels = descriptor.numChannels + + // 1. RMSNorm + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. linear1 + gate, both [B*seq, ffnChannels] + let a = MatMulLayer(graph: graph, descriptor: descriptor.linear1, sourceTensor: normedNHWC).resultTensor + let gate = MatMulLayer(graph: graph, descriptor: descriptor.linearGate, sourceTensor: normedNHWC).resultTensor + + // 3. SwiGLU: SiLU(a) * gate, SiLU(a) = a * sigmoid(a) + let siluA = graph.multiplication(a, graph.sigmoid(with: a, name: nil), name: nil) + let h = graph.multiplication(siluA, gate, name: nil) + + // 4. linear2 -> [B*seq, numChannels] + let out = MatMulLayer(graph: graph, descriptor: descriptor.linear2, sourceTensor: h).resultTensor + + // 5. reshape to NHWC then NCHW, masked residual + let outNHWC = graph.reshape(out, shape: [-1, nnYLen, nnXLen, numChannels as NSNumber], name: nil) + let outNCHW = graph.transpose(outNHWC, permutation: [0, 3, 1, 2], name: nil) + let masked = graph.multiplication(outNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } +} + // MARK: - Block Implementations /// A class that represents a Residual Block layer @@ -1241,6 +1750,26 @@ struct BlockStack { optimizeIdentityMask: optimizeIdentityMask) blockInput = ordinary.resultTensor + case let attnDescriptor as SWTransformerAttentionBlockDesc: + let attn = TransformerAttentionBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: attnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = attn.resultTensor + case let ffnDescriptor as SWTransformerFFNBlockDesc: + let ffn = TransformerFFNBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: ffnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = ffn.resultTensor default: blockInput = sourceTensor } @@ -1483,7 +2012,9 @@ public class SWTrunkDesc { let initialMatMul: SWMatMulLayerDesc let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? let blockDescriptors: [BlockDescriptor] + let trunkNormKind: Int let trunkTipBN: SWBatchNormLayerDesc + let trunkTipRMSNorm: SWRMSNormLayerDesc let trunkTipActivation: ActivationKind init( @@ -1496,7 +2027,9 @@ public class SWTrunkDesc { initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) { self.version = version @@ -1508,7 +2041,9 @@ public class SWTrunkDesc { self.initialMatMul = initialMatMul self.sgfMetadataEncoder = sgfMetadataEncoder self.blockDescriptors = blockDescriptors + self.trunkNormKind = trunkNormKind self.trunkTipBN = trunkTipBN + self.trunkTipRMSNorm = trunkTipRMSNorm self.trunkTipActivation = trunkTipActivation } } @@ -1523,7 +2058,9 @@ public func createSWTrunkDesc( initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int32, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) -> SWTrunkDesc { return SWTrunkDesc( @@ -1536,7 +2073,9 @@ public func createSWTrunkDesc( initialMatMul: initialMatMul, sgfMetadataEncoder: sgfMetadataEncoder, blockDescriptors: blockDescriptors, + trunkNormKind: Int(trunkNormKind), trunkTipBN: trunkTipBN, + trunkTipRMSNorm: trunkTipRMSNorm, trunkTipActivation: trunkTipActivation) } @@ -1632,21 +2171,34 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - let trunkTipBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.trunkTipBN, - nnXLen: nnXLen, - nnYLen: nnYLen, - optimizeIdentityMask: optimizeIdentityMask) + // TRUNK_NORM_KIND_RMSNORM == 1: trunk tip uses RMSNorm with a fused activation. + // Otherwise (standard): BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == 1 { + let trunkTipRMSNorm = TrunkRMSNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipRMSNorm, + activationKind: descriptor.trunkTipActivation) - let trunkTipActivation = ActivationLayer( - graph: graph, - sourceTensor: trunkTipBN.resultTensor, - activationKind: descriptor.trunkTipActivation) + resultTensor = trunkTipRMSNorm.resultTensor + } else { + let trunkTipBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipActivation = ActivationLayer( + graph: graph, + sourceTensor: trunkTipBN.resultTensor, + activationKind: descriptor.trunkTipActivation) - resultTensor = trunkTipActivation.resultTensor + resultTensor = trunkTipActivation.resultTensor + } assert(resultTensor.shape?.count == 4) } From 6f8314b7f6634afc0356a2ad5a1a061a48e2aaa9 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:43:20 +0800 Subject: [PATCH 09/18] Fix Metal GPU crash on GQA transformer attention (NDArray INT_MAX) GQA models (numKVHeads != numHeads, e.g. b7c96h6kv3qk32v16tflrs) crashed on the Metal GPU path with: MPSNDArray.mm: NDArray dimension length > INT_MAX repeatKVHeads expanded the KV heads via reshape -> broadcast -> reshape, passing -1 for the batch dim in the broadcast target shape. Unlike reshape, MPSGraph.broadcast(_:shape:) does not infer -1 and treats it as a literal (near-INT_MAX) dimension, tripping the NDArray assertion. Replace the broadcast with a shape-safe slice + concat: slice each KV head (dim 1) and concatenate groupSize copies consecutively, so query head h uses kv = h / groupSize, matching the Eigen reference (kvh = h / kvGroupSize). No -1 broadcast. Verified: testgpuerror GPU vs Eigen reference at 9/13/19 now passes (~0.00003% winrate); non-GQA models (incl. b10c384h6) unaffected since the GQA branch is gated on numKVHeads != numHeads. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index f275c7925..dc4d22d53 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -1401,11 +1401,19 @@ struct TransformerAttentionBlock { /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { - // Insert a group axis then broadcast: [B,kv,1,seq,dim] -> [B,kv,group,seq,dim] -> [B,kv*group,seq,dim] - let expanded = graph.reshape(x, shape: [-1, numKVHeads as NSNumber, 1, seq as NSNumber, headDim as NSNumber], name: nil) - let targetShape: [NSNumber] = [-1, numKVHeads as NSNumber, groupSize as NSNumber, seq as NSNumber, headDim as NSNumber] - let broadcast = graph.broadcast(expanded, shape: targetShape, name: nil) - return graph.reshape(broadcast, shape: [-1, (numKVHeads * groupSize) as NSNumber, seq as NSNumber, headDim as NSNumber], name: nil) + // Repeat each KV head groupSize times consecutively so query head h uses kv = h / groupSize, + // matching the Eigen reference (kvh = h / kvGroupSize). We slice each KV head and concat the + // copies along the head axis. Note: MPSGraph.broadcast(_:shape:) does NOT infer -1, so a + // reshape+broadcast approach with a dynamic batch dim triggers an NDArray INT_MAX assertion; + // slice+concat is shape-safe with no -1 broadcast. + var heads: [MPSGraphTensor] = [] + heads.reserveCapacity(numKVHeads * groupSize) + for kv in 0.. Date: Mon, 1 Jun 2026 15:43:33 +0800 Subject: [PATCH 10/18] Fix dropped SiLU activation in CoreML value/policy/meta heads The MIL builder's inline activation dispatch (buildValueHead v2, policy-head pass activation, and both SGF metadata encoder layers) handled only ReLU and Mish; SiLU silently fell through to the else branch and applied NO activation at all. This corrupted the value-head pool -> v2 -> v3 scalar path for every SiLU model, producing large errors in winrate/score/lead while ownership (which branches off v1, before v2) stayed correct. Add an ActivationType::Silu branch (addSiluOps) at all four sites. The generic conv/BN activation path already handled SiLU, which is why the trunk and v1/ownership were fine. Root-caused via systematic debugging: CoreML-CPU(fp32) error was identical to ANE (-> logical bug, not fp16), and perfect ownership with wrong scalars localized it to the value-head post-pooling path. This corrects the earlier "ANE is fp16-precision-limited (~5%)" conclusion -- that 5.66% on b10c384h6 was this bug. After the fix, testgpuerror ANE vs Eigen drops to GPU-level accuracy for all models: b10c384h6 5.66% -> ~0.00005-0.0002% cnorm 11-13% -> ~0.00007% rsnh 22-29% -> ~0.00004-0.0001% Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/external/katagocoreml/src/builder/MILBuilder.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index 2a0bbf44a..3305737cb 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -2593,6 +2593,8 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, pass_activated, ph.gpool_to_pass_mul.out_channels); } else if (ph.pass_activation->activation_type == ActivationType::Mish) { addMishOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); + } else if (ph.pass_activation->activation_type == ActivationType::Silu) { + addSiluOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); } else { pass_activated = pass_biased; } @@ -2640,6 +2642,8 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, v2, vh.v2_mul.out_channels); } else if (vh.v2_activation.activation_type == ActivationType::Mish) { addMishOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); + } else if (vh.v2_activation.activation_type == ActivationType::Silu) { + addSiluOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); } else { v2 = v2_bias; } @@ -2676,6 +2680,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act1, encoder.mul1.out_channels); } else if (encoder.act1.activation_type == ActivationType::Mish) { addMishOps(block, bias1, act1, 2, encoder.mul1.out_channels); + } else if (encoder.act1.activation_type == ActivationType::Silu) { + addSiluOps(block, bias1, act1, 2, encoder.mul1.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); @@ -2698,6 +2704,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act2, encoder.mul2.out_channels); } else if (encoder.act2.activation_type == ActivationType::Mish) { addMishOps(block, bias2, act2, 2, encoder.mul2.out_channels); + } else if (encoder.act2.activation_type == ActivationType::Silu) { + addSiluOps(block, bias2, act2, 2, encoder.mul2.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); From 792c4760628dbed29aa1d6d120494ff7643d7c00 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:23:17 +0800 Subject: [PATCH 11/18] Implement GQA support in CoreML/ANE MIL attention builder The CoreML MIL builder threw "GQA (numKVHeads != numHeads) not supported" for grouped-query-attention transformer models, while the Metal GPU (MPSGraph) path already handled GQA. Port that support to the MIL builder. In buildTransformerAttentionBlock, remove the throw guard and, after the RoPE block and before the scores matmul, repeat each KV head groupSize (= numHeads/numKVHeads) times along the head axis via slice_by_size + concat (interleave=false), so query head h consumes kv head h/groupSize. This matches the Eigen reference (kvh = h/kvGroupSize) and the GPU repeatKVHeads ordering. RoPE stays before the repeat (its cos/sin tables are numKVHeads-shaped). The block is gated by numKVHeads != numHeads, so the standard MHA path is unchanged. Verified on b7c96h6kv3qk32v16tflrs-fson-bnh (6 query / 3 KV heads, qk32/v16) vs Eigen reference: ANE testgpuerror 9/13/19 = 0.00002-0.00003% winrate (previously a hard throw); GPU unchanged; non-GQA model ANE error identical to pre-change; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 52 +++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index 3305737cb..b90da55b6 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -1891,10 +1891,6 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; - if (numKVHeads != numHeads) { - throw std::runtime_error(desc.name + ": GQA (numKVHeads != numHeads) not supported in CoreML backend"); - } - auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, const std::vector& dims) { std::string shapeName = out + "_shape"; @@ -2024,6 +2020,54 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI kh = applyRope(kh, numKVHeads, "k"); } + // GQA: when numKVHeads < numHeads, repeat each KV head groupSize times along the head + // axis (axis 1) so query head h consumes kv head (h / groupSize). RoPE has already been + // applied above to the unexpanded kh (kh = applyRope(kh, numKVHeads, "k")), mirroring the + // GPU path (metallayers.swift repeatKVHeads runs AFTER applyRope). We slice each KV head + // and concat its copies consecutively, so the resulting head index is kv*groupSize + g; + // query head h then maps to kv = h/groupSize == (h*numKVHeads)/numHeads (exact divisor, + // the same formula the qh RoPE table uses) == Eigen's kvh = h/kvGroupSize. slice_by_size + + // concat (not reshape+broadcast) avoids the dynamic -1 batch broadcast pitfall, same as the + // GPU code. The repeat is required so the scores (qh@kh^T) and attnOut (attn@vh) matmuls see + // matching [B,numHeads,...] batch dims instead of numHeads vs numKVHeads (no broadcast). + if (numKVHeads != numHeads) { + const int groupSize = numHeads / numKVHeads; + auto repeatKVHeads = [&](const std::string& x, const std::string& tag, int headDim) { + std::vector parts; + parts.reserve(static_cast(numKVHeads) * groupSize); + for (int kv = 0; kv < numKVHeads; kv++) { + for (int g = 0; g < groupSize; g++) { + std::string part = genVarName(prefix + "_" + tag + "_slc"); + std::string beginName = part + "_begin", sizeName = part + "_size"; + addIntArrayConstOp(block, beginName, {0, kv, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, headDim}); + auto* sop = block->add_operations(); + sop->set_type("slice_by_size"); + (*sop->mutable_inputs())["x"].add_arguments()->set_name(x); + (*sop->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*sop->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(sop, part, {-1, 1, seq, headDim}); + parts.push_back(part); + } + } + std::string out = genVarName(prefix + "_" + tag + "_exp"); + std::string axisName = out + "_axis", interleaveName = out + "_interleave"; + addIntScalarConstOp(block, axisName, 1); + addBoolScalarConstOp(block, interleaveName, false); + auto* cop = block->add_operations(); + cop->set_type("concat"); + auto& cin = *cop->mutable_inputs(); + for (const std::string& part : parts) + cin["values"].add_arguments()->set_name(part); + cin["axis"].add_arguments()->set_name(axisName); + cin["interleave"].add_arguments()->set_name(interleaveName); + setShape(cop, out, {-1, numHeads, seq, headDim}); + return out; + }; + kh = repeatKVHeads(kh, "khrep", qHeadDim); + vh = repeatKVHeads(vh, "vhrep", vHeadDim); + } + std::string scores = genVarName(prefix + "_scores"); matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); std::string scaleName = prefix + "_scale"; From 3839e529160e103d4120d879395c126dcb57ae0c Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 07:07:55 +0800 Subject: [PATCH 12/18] Fix CoreML/ANE FP16 transformer accuracy via precision tiers Transformer models failed testgpuerror on the CoreML/ANE FP16 path: the ANE accumulates FP16 matmuls AND convs in FP16 (unlike OpenCL/CUDA/TRT, which accumulate in FP32), so wide/deep transformers lose too much precision and miss the thresholds at larger board sizes. BF16 is not an option (no compute path in CoreML: cast op, ArrayFeatureType and MLMultiArray all lack bf16; coremltools confirms FLOAT16/FLOAT32 only). Follow KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), channel-gated for the ANE since every FP32 op runs off the FP16-only ANE: - RMSNorm reduction cores: FP32 in FP16 mode (always). - Non-spatial (FFN/Q-K-V proj/pooling/matmul): FP32 (always). MIL `linear` needs const weight/bias so it can't runtime-cast; only `matmul` is wrapped. - Convs: FP32 only for wide trunks (>= 320ch); narrower keep convs on-ANE. - Narrow trunks (< 256ch) sit on the testgpuerror thresholds and no partial FP32 config passes all board sizes (islands cast back to FP16 leave a noisy FP16 spatial stream); build them fully FP32 (off-ANE, cheap since small). Weights stay FP16-stored via runtime up-casts, except full-FP32 models. Add per-weight FP32 serialization (WeightEntry.is_fp32) so a const declared FP32 inside an otherwise-FP16 model is stored FP32 (fixes the load-time "storage and type have different number of elements" abort and enables the full-FP32 tier). Also fixes addFloatScalarConstOp keying storage off m_use_fp16 instead of the declared m_weight_dtype. Result: all 4 transformer test models (b10c384h6/b4c256h4/b7c96h3/ b7c96h6kv3-GQA) pass testgpuerror on ANE FP16 at sizes 9/13/19; runtests and runnnlayertests pass. All changes gated on m_use_fp16; FP32 mode unchanged. The 256/320 channel thresholds are width heuristics validated on these models. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 271 +++++++++++++++--- .../katagocoreml/src/builder/MILBuilder.hpp | 16 ++ .../katagocoreml/src/builder/Operations.cpp | 4 +- .../katagocoreml/src/builder/Operations.hpp | 7 +- .../src/serializer/WeightSerializer.cpp | 6 +- 5 files changed, 268 insertions(+), 36 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index b90da55b6..bef0fea73 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -31,7 +31,24 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, ? CoreML::Specification::MILSpec::DataType::FLOAT16 : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) - , m_var_counter(0) {} + , m_var_counter(0) { + // Precision tiers in FP16 mode (the ANE accumulates FP16 in FP16; FP32 ops run off the FP16-only + // ANE). NARROW transformer trunks are unreliable on the FP16 ANE: their policy/value metrics sit + // right on the testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial + // FP32 leaves a noisy FP16 spatial stream). So build narrow trunks FULLY in FP32 (off-ANE, but + // cheap since narrow models are small; correct because it equals the FP32 reference). Weights are + // stored FP32 via per-weight serialization. Wider trunks use partial FP32: non-spatial (matmuls + + // pooling) always FP32; convs FP32 only for very wide trunks (kept on the ANE for narrower ones). + const int trunkChannels = model.trunk.trunk_num_channels; + const bool full_fp32 = use_fp16 && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + if (full_fp32) { + m_use_fp16 = false; + m_use_fp16_io = false; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + m_nonspatial_fp32 = m_use_fp16; + m_conv_fp32 = m_use_fp16 && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; +} void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { auto* dim = tensor_type->add_dimensions(); @@ -213,8 +230,10 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage - m_ops.registerWeight(name, data, shape); + // Register weight for blob storage. Mark FP32 storage when this const is declared FP32 (e.g. + // inside an FP32 sub-region of an otherwise-FP16 model) so storage matches the declared type. + m_ops.registerWeight(name, data, shape, + m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); // Add const operation auto* op = block->add_operations(); @@ -329,7 +348,11 @@ void MILBuilder::addFloatScalarConstOp(CoreML::Specification::MILSpec::Block* bl val_type->set_datatype(m_weight_dtype); val_type->set_rank(0); - if (m_use_fp16) { + // Key the storage format off the DECLARED dtype (m_weight_dtype), not the global m_use_fp16: + // a temporarily-flipped FP32 sub-region (m_weight_dtype=FLOAT32 while m_use_fp16 stays true) + // must store FP32 floats, or CoreML rejects the model ("storage and type have different number + // of elements"). For all non-flipped calls m_weight_dtype tracks m_use_fp16, so this is a no-op. + if (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT16) { // For FP16, use bytes storage with FP16 representation MILBlob::Fp16 fp16_val = MILBlob::Fp16::FromFloat(value); std::string bytes_data(reinterpret_cast(&fp16_val.bytes), sizeof(fp16_val.bytes)); @@ -427,6 +450,42 @@ void MILBuilder::addCastOp(CoreML::Specification::MILSpec::Block* block, } } +std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims) { + std::string out = genVarName(input + "_cast"); + std::string dtName = out + "_dt"; + { + auto* op = block->add_operations(); + op->set_type("const"); + auto& na = (*op->mutable_attributes())["name"]; + na.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + na.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtName); + auto& va = (*op->mutable_attributes())["val"]; + va.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + va.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtype); + auto* o = op->add_outputs(); + o->set_name(dtName); + o->mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + } + auto* op = block->add_operations(); + op->set_type("cast"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(input); + (*op->mutable_inputs())["dtype"].add_arguments()->set_name(dtName); + auto* o = op->add_outputs(); + o->set_name(out); + auto* tt = o->mutable_type()->mutable_tensortype(); + tt->set_datatype(dtype == "fp32" ? CoreML::Specification::MILSpec::DataType::FLOAT32 + : CoreML::Specification::MILSpec::DataType::FLOAT16); + tt->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + if (d < 0) tt->add_dimensions()->mutable_unknown()->set_variadic(false); + else tt->add_dimensions()->mutable_constant()->set_size(d); + } + return out; +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -567,6 +626,21 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, tt->add_dimensions()->mutable_constant()->set_size(4); } + // Channel-gated FP32 convs. The ANE accumulates FP16 convs in FP16, which loses too much + // precision for WIDE trunks and fails testgpuerror at large board sizes (validated: 384ch + // fails, <=256ch is fine FP16-on-ANE). For wide trunks (>= threshold) run convs in FP32 (weights + // cast up at runtime, stored fp16). FP32 convs can't run on the fp16-only ANE, so only the wide + // models that actually need it pay that off-ANE cost; narrow models keep convs on the ANE. + const bool convFp32 = m_conv_fp32; + std::string convX = input, convW = weight_name, convOut = output; + auto savedConvDtype = m_weight_dtype; + if (convFp32) { + convX = castFixed(block, input, "fp32", {-1, layer.in_channels, m_board_y_size, m_board_x_size}); + convW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + convOut = output + "_cf32"; + } + // Add conv operation referencing all const parameters auto* op = block->add_operations(); op->set_type("conv"); @@ -578,12 +652,12 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, inputs["pad"].add_arguments()->set_name(pad_name); inputs["pad_type"].add_arguments()->set_name(pad_type_name); inputs["strides"].add_arguments()->set_name(strides_name); - inputs["weight"].add_arguments()->set_name(weight_name); - inputs["x"].add_arguments()->set_name(input); + inputs["weight"].add_arguments()->set_name(convW); + inputs["x"].add_arguments()->set_name(convX); // Output with dimensions [batch, out_channels, height, width] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(convOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(4); @@ -591,6 +665,11 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + + if (convFp32) { + m_weight_dtype = savedConvDtype; + addCastOp(block, convOut, output, "fp16", {-1, layer.out_channels, m_board_y_size, m_board_x_size}); + } } // Helper: Set output tensor type with 4D shape [batch, C, H, W] @@ -945,23 +1024,38 @@ void MILBuilder::addMatMulOp(CoreML::Specification::MILSpec::Block* block, CoreML::Specification::MILSpec::DataType::BOOL); } + // Non-spatial matmul in FP32 (KataGo FP16 convention; weights cast up at runtime, stored fp16). + std::string mmIn = input, mmW = weight_name, mmOut = output; + auto savedMmDtype = m_weight_dtype; + if (m_nonspatial_fp32) { + mmIn = castFixed(block, input, "fp32", {-1, layer.in_channels}); + mmW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + mmOut = output + "_mmf32"; + } + // Add matmul operation auto* op = block->add_operations(); op->set_type("matmul"); auto& inputs = *op->mutable_inputs(); inputs["transpose_x"].add_arguments()->set_name(transpose_x_name); inputs["transpose_y"].add_arguments()->set_name(transpose_y_name); - inputs["x"].add_arguments()->set_name(input); - inputs["y"].add_arguments()->set_name(weight_name); + inputs["x"].add_arguments()->set_name(mmIn); + inputs["y"].add_arguments()->set_name(mmW); // Output with 2D shape [batch, out_channels] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(mmOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(2); setBatchDimension(out_type); out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); + + if (m_nonspatial_fp32) { + m_weight_dtype = savedMmDtype; + addCastOp(block, mmOut, output, "fp16", {-1, layer.out_channels}); + } } void MILBuilder::addMatBiasOp(CoreML::Specification::MILSpec::Block* block, @@ -1022,7 +1116,9 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, std::vector bias_shape = {static_cast(bias.num_channels)}; addConstOp(block, bias_name, bias.weights, bias_shape); - // Add linear operation + // NOTE: the MIL `linear` op requires const weight/bias, so the runtime-cast-to-FP32 trick can't + // be applied here (unlike `matmul`). Value-head linear stays FP16; if a model ever needs it in + // FP32, rewrite as matmul+add (matmul accepts cast inputs). auto* op = block->add_operations(); op->set_type("linear"); auto& inputs = *op->mutable_inputs(); @@ -1718,8 +1814,20 @@ std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Bl setShape(op, out, dims); }; + // RMSNorm reduction core: square -> mean over channels -> rsqrt. In FP16 mode compute this + // core in FP32 (cast input up, flip the working dtype so the core's op outputs + eps scalar are + // FP32, then cast 1/rms back down). The FP16 channel reduction loses too much precision on the + // ANE; only this core is FP32 - the scaling/weight/mask below stay FP16. No addConstOp lives in + // the flipped window, so weight serialization is unaffected. + auto savedDtype = m_weight_dtype; + std::string sqSrc = input; + if (m_use_fp16) { + sqSrc = genVarName(prefix + "_in32"); + addCastOp(block, input, sqSrc, "fp32", {-1, C, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } std::string sq = genVarName(prefix + "_sq"); - emit2("mul", input, input, sq, {-1, C, H, W}); + emit2("mul", sqSrc, sqSrc, sq, {-1, C, H, W}); // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can // overflow FP16 (and the FP16 accumulation on ANE) for large activations. @@ -1739,13 +1847,19 @@ std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Bl // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. std::string epsName = prefix + "_eps"; addFloatScalarConstOp(block, epsName, desc.epsilon); - std::string inv = genVarName(prefix + "_inv"); + std::string invCore = genVarName(prefix + "_inv"); { auto* op = block->add_operations(); op->set_type("rsqrt"); (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); - setShape(op, inv, {-1, 1, H, W}); + setShape(op, invCore, {-1, 1, H, W}); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", {-1, 1, H, W}); } std::string normalized = genVarName(prefix + "_norm"); emit2("mul", input, inv, normalized, {-1, C, H, W}); @@ -1788,8 +1902,22 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b setShape(op, out, dims); }; + // Variance core (mask -> square -> reduce -> rsqrt) in FP32 when in FP16 mode. The trunk-tip + // norm in particular reduces over many elements and loses too much precision in FP16 on the + // ANE; compute the core in FP32 and cast 1/rms back to FP16. Only the core is FP32 - gamma/beta, + // the activation and the final mask below stay FP16. No addConstOp lives in the flipped window. + auto savedDtype = m_weight_dtype; + std::string tinput = input; + std::string tmask = mask; + if (m_use_fp16) { + tinput = genVarName(prefix + "_in32"); + addCastOp(block, input, tinput, "fp32", {-1, C, H, W}); + tmask = genVarName(prefix + "_mask32"); + addCastOp(block, mask, tmask, "fp32", {-1, 1, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } std::string masked = genVarName(prefix + "_premask"); - emit2("mul", input, mask, masked, {-1, C, H, W}); + emit2("mul", tinput, tmask, masked, {-1, C, H, W}); std::string sq = genVarName(prefix + "_sq"); emit2("mul", masked, masked, sq, {-1, C, H, W}); @@ -1813,7 +1941,7 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b setShape(op, meanAll, {-1, 1, 1, 1}); } std::string count = genVarName(prefix + "_count"); - reduceSum(mask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + reduceSum(tmask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) std::string totalPosName = prefix + "_totalpos"; addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); std::string scaleF = genVarName(prefix + "_scalef"); @@ -1839,13 +1967,19 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. std::string epsName = prefix + "_eps"; addFloatScalarConstOp(block, epsName, desc.epsilon); - std::string inv = genVarName(prefix + "_inv"); + std::string invCore = genVarName(prefix + "_inv"); { auto* op = block->add_operations(); op->set_type("rsqrt"); (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); - setShape(op, inv, denomDims); + setShape(op, invCore, denomDims); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", denomDims); } std::string normalized = genVarName(prefix + "_norm"); emit2("mul", input, inv, normalized, {-1, C, H, W}); @@ -1938,11 +2072,25 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); std::string x2d = genVarName(prefix + "_x2d"); reshape(nhwc, x2d, {-1, C}, {-1, C}); + // Q/K/V projection matmuls in FP32 (non-spatial, per KataGo's FP16 convention): they reduce over + // C channels and the ANE's FP16 accumulation loses too much precision for wide models. Weights + // stay fp16-stored (cast up at runtime); output cast back to FP16 for the FP16 head reshapes. auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { std::string wName = nm + "_w"; addConstOp(block, wName, w.weights, w.getWeightShape()); std::string out = genVarName(nm); - matmul(x2d, wName, out, {-1, total}, false, false); + if (m_nonspatial_fp32) { + std::string x32 = castFixed(block, x2d, "fp32", {-1, C}); + std::string w32 = castFixed(block, wName, "fp32", w.getWeightShape()); + auto sd = m_weight_dtype; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string o32 = genVarName(nm + "_f32"); + matmul(x32, w32, o32, {-1, total}, false, false); + m_weight_dtype = sd; + out = castFixed(block, o32, "fp16", {-1, total}); + } else { + matmul(x2d, wName, out, {-1, total}, false, false); + } return out; }; std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); @@ -2218,14 +2366,29 @@ std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec: std::string x2d = genVarName(prefix + "_x2d"); reshape(nhwc, x2d, {-1, C}, {-1, C}); + // FFN matmuls in FP32 (weights cast up at runtime, stored fp16) — KataGo's FP16 convention is + // spatial(convs)=FP16, non-spatial(matmuls)=FP32 (see openclbackend.cpp). The ANE accumulates + // FP16 matmuls in FP16, which loses too much precision over C/ffn; run them in FP32 instead. std::string w1 = prefix + "_w1"; addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); - std::string a = genVarName(prefix + "_a"); - matmul(x2d, w1, a, {-1, ffn}); std::string wg = prefix + "_wg"; addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + + auto savedDtype = m_weight_dtype; + std::string mx2d = x2d, mw1 = w1, mwg = wg, mw2 = w2; + if (m_nonspatial_fp32) { + mx2d = castFixed(block, x2d, "fp32", {-1, C}); + mw1 = castFixed(block, w1, "fp32", desc.linear1.getWeightShape()); + mwg = castFixed(block, wg, "fp32", desc.linear_gate.getWeightShape()); + mw2 = castFixed(block, w2, "fp32", desc.linear2.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string a = genVarName(prefix + "_a"); + matmul(mx2d, mw1, a, {-1, ffn}); std::string g = genVarName(prefix + "_g"); - matmul(x2d, wg, g, {-1, ffn}); + matmul(mx2d, mwg, g, {-1, ffn}); std::string sig = genVarName(prefix + "_sig"); { @@ -2239,10 +2402,13 @@ std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec: std::string h = genVarName(prefix + "_h"); binary("mul", siluA, g, h, {-1, ffn}); - std::string w2 = prefix + "_w2"; - addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); - std::string o = genVarName(prefix + "_o"); - matmul(h, w2, o, {-1, C}); + std::string oCore = genVarName(prefix + "_o"); + matmul(h, mw2, oCore, {-1, C}); + std::string o = oCore; + if (m_nonspatial_fp32) { + m_weight_dtype = savedDtype; + o = castFixed(block, oCore, "fp16", {-1, C}); + } std::string oNHWC = genVarName(prefix + "_onhwc"); reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); @@ -2443,9 +2609,25 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling + // Global pooling. Non-spatial per KataGo's FP16 convention -> FP32 (openclbackend.cpp: pooling + // an FP16 tensor produces FP32 pooled values). The spatial sum over H*W loses too much precision + // in FP16 at larger board sizes, corrupting the bias fed back into the whole trunk. No + // addConstOp in the pooling -> flipping m_weight_dtype is safe. std::string gpool_features = genVarName(prefix + "_gpool_features"); - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string gpIn32 = castFixed(block, gpool_bn_out, "fp32", {-1, block_desc.gpool_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string gpMask = mask; + if (!m_optimize_identity_mask) + gpMask = castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string gpOut32 = genVarName(prefix + "_gpool_features_f32"); + addGlobalPoolingOps(block, gpIn32, gpMask, block_desc.gpool_conv.out_channels, gpOut32); + m_weight_dtype = savedDtype; + gpool_features = castFixed(block, gpOut32, "fp16", {-1, block_desc.gpool_conv.out_channels * 3}); + } else { + addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + } // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -2577,9 +2759,22 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 + // Global pooling on G1 — non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum + // loses precision; feeds the policy bias, affecting policyKLDiv). No addConstOp in pooling. std::string g1_pooled = genVarName("policy_g1_pool"); - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string gpIn32 = castFixed(block, g1, "fp32", {-1, ph.g1_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string gpMask = m_optimize_identity_mask ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string gpOut32 = genVarName("policy_g1_pool_f32"); + addGlobalPoolingOps(block, gpIn32, gpMask, ph.g1_conv.out_channels, gpOut32); + m_weight_dtype = savedDtype; + addCastOp(block, gpOut32, g1_pooled, "fp16", {-1, ph.g1_conv.out_channels * 3}); + } else { + addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + } // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2669,9 +2864,21 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) + // Global pooling (value head version) — non-spatial -> FP32 (KataGo FP16 convention). std::string v1_pooled = genVarName("value_v1_pool"); - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string vpIn32 = castFixed(block, v1, "fp32", {-1, vh.v1_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string vpMask = m_optimize_identity_mask ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string vpOut32 = genVarName("value_v1_pool_f32"); + addGlobalPoolingValueOps(block, vpIn32, vpMask, vh.v1_conv.out_channels, vpOut32); + m_weight_dtype = savedDtype; + addCastOp(block, vpOut32, v1_pooled, "fp16", {-1, vh.v1_conv.out_channels * 3}); + } else { + addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + } // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 5d25b963a..ad67f150f 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -43,6 +43,14 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; + // FP32 in FP16 mode follows KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), but + // FP32 ops run off the FP16-only ANE, so convs are channel-gated to only the wide trunks that + // need it. RMSNorm reductions: always FP32 (cheap, needed by all). Non-spatial matmuls+pooling: + // always FP32 (every width needs it at some board size). Convs: FP32 only for wide trunks. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; @@ -102,6 +110,14 @@ class MILBuilder { const std::string& dtype, const std::vector& shape); + // Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for + // weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the + // new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim. + std::string castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..1c625acdd 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -14,12 +14,14 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { WeightEntry entry; entry.name = name; entry.data = data; entry.shape = shape; entry.blob_offset = 0; // Will be set during serialization + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..a9d2a1466 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -16,6 +16,8 @@ struct WeightEntry { std::vector data; std::vector shape; uint64_t blob_offset = 0; // Set during serialization + bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an + // FP32 sub-region of an otherwise-FP16 model). Else stored per global mode. }; /// Precomputed constants for identity mask optimization @@ -51,10 +53,11 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight tensor and return its reference name. is_fp32 marks it for FP32 storage. std::string registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); /// Get all registered weights const std::vector& getWeights() const { return m_weights; } diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..e8fe861c8 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,7 +15,11 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { - if (use_fp16) { + // Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not + // declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so + // stored bytes stay consistent with each const's declared dtype. + const bool store_fp16 = use_fp16 && !entry.is_fp32; + if (store_fp16) { // Convert FP32 weights to FP16 std::vector fp16_data(entry.data.size()); for (size_t i = 0; i < entry.data.size(); ++i) { From 3eb81ce66a410e0515ba2d37787789c4e8539dd3 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 07:40:26 +0800 Subject: [PATCH 13/18] Refactor: dedupe CoreML global-pooling FP32 wrap Three near-identical blocks wrapped global pooling in FP32 (policy head, value head, gpool residual block): cast input/mask up to FP32, flip m_weight_dtype, pool, restore, cast pooled features back to FP16 - with inconsistent save-variable names and one site using castFixed vs addCastOp for the output cast. Extract a single addGlobalPoolingFp32(input, mask, channels, output, valueVariant) helper and a small RAII ScopedFp32 guard for the temporary m_weight_dtype flip. The three call sites become one-liners. Behavior-preserving: same emitted op sequence; testgpuerror output is byte-identical across all precision tiers (partial-FP32 b10c384h6, full-FP32 b7c96h3, non-spatial-FP32 b4c256h4), all 12 transformer gate runs pass, runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 95 ++++++++++--------- .../katagocoreml/src/builder/MILBuilder.hpp | 9 ++ 2 files changed, 57 insertions(+), 47 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index bef0fea73..b4974aa39 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -11,6 +11,20 @@ namespace katagocoreml { +namespace { +// RAII: set a dtype slot to FLOAT32 for the current scope and restore it on exit. Used to emit a +// sub-region of ops in FP32 inside an otherwise-FP16 model. +struct ScopedFp32 { + CoreML::Specification::MILSpec::DataType& slot; + CoreML::Specification::MILSpec::DataType saved; + explicit ScopedFp32(CoreML::Specification::MILSpec::DataType& s) + : slot(s), saved(s) { s = CoreML::Specification::MILSpec::DataType::FLOAT32; } + ~ScopedFp32() { slot = saved; } + ScopedFp32(const ScopedFp32&) = delete; + ScopedFp32& operator=(const ScopedFp32&) = delete; +}; +} // namespace + MILBuilder::MILBuilder(const KataGoModelDesc& model, int board_x_size, int board_y_size, @@ -486,6 +500,32 @@ std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, return out; } +void MILBuilder::addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant) { + auto pool = [&](const std::string& in, const std::string& msk, const std::string& out) { + if (valueVariant) addGlobalPoolingValueOps(block, in, msk, channels, out); + else addGlobalPoolingOps(block, in, msk, channels, out); + }; + // Non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum over H*W loses too much + // precision at larger board sizes). No addConstOp in the pooling, so flipping m_weight_dtype is + // safe. Cast input/mask up, pool in FP32, cast the [N, channels*3] features back to FP16. + if (!m_nonspatial_fp32) { + pool(input, mask, output); + return; + } + std::string in32 = castFixed(block, input, "fp32", {-1, channels, m_board_y_size, m_board_x_size}); + std::string mask32 = m_optimize_identity_mask + ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + std::string out32 = genVarName(output + "_f32"); + { ScopedFp32 g(m_weight_dtype); pool(in32, mask32, out32); } + addCastOp(block, out32, output, "fp16", {-1, channels * 3}); +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -2609,25 +2649,11 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling. Non-spatial per KataGo's FP16 convention -> FP32 (openclbackend.cpp: pooling - // an FP16 tensor produces FP32 pooled values). The spatial sum over H*W loses too much precision - // in FP16 at larger board sizes, corrupting the bias fed back into the whole trunk. No - // addConstOp in the pooling -> flipping m_weight_dtype is safe. + // Global pooling (FP32 when m_nonspatial_fp32 -- see addGlobalPoolingFp32). Feeds a bias back + // into the whole trunk, so the FP16 spatial sum must not lose precision for wide trunks. std::string gpool_features = genVarName(prefix + "_gpool_features"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string gpIn32 = castFixed(block, gpool_bn_out, "fp32", {-1, block_desc.gpool_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string gpMask = mask; - if (!m_optimize_identity_mask) - gpMask = castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string gpOut32 = genVarName(prefix + "_gpool_features_f32"); - addGlobalPoolingOps(block, gpIn32, gpMask, block_desc.gpool_conv.out_channels, gpOut32); - m_weight_dtype = savedDtype; - gpool_features = castFixed(block, gpOut32, "fp16", {-1, block_desc.gpool_conv.out_channels * 3}); - } else { - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); - } + addGlobalPoolingFp32(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features, + /*valueVariant=*/false); // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -2759,22 +2785,9 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 — non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum - // loses precision; feeds the policy bias, affecting policyKLDiv). No addConstOp in pooling. + // Global pooling on G1 (FP32 when m_nonspatial_fp32; feeds the policy bias / policyKLDiv). std::string g1_pooled = genVarName("policy_g1_pool"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string gpIn32 = castFixed(block, g1, "fp32", {-1, ph.g1_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string gpMask = m_optimize_identity_mask ? mask - : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string gpOut32 = genVarName("policy_g1_pool_f32"); - addGlobalPoolingOps(block, gpIn32, gpMask, ph.g1_conv.out_channels, gpOut32); - m_weight_dtype = savedDtype; - addCastOp(block, gpOut32, g1_pooled, "fp16", {-1, ph.g1_conv.out_channels * 3}); - } else { - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); - } + addGlobalPoolingFp32(block, g1, mask, ph.g1_conv.out_channels, g1_pooled, /*valueVariant=*/false); // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2864,21 +2877,9 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) — non-spatial -> FP32 (KataGo FP16 convention). + // Global pooling (value head version; FP32 when m_nonspatial_fp32). std::string v1_pooled = genVarName("value_v1_pool"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string vpIn32 = castFixed(block, v1, "fp32", {-1, vh.v1_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string vpMask = m_optimize_identity_mask ? mask - : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string vpOut32 = genVarName("value_v1_pool_f32"); - addGlobalPoolingValueOps(block, vpIn32, vpMask, vh.v1_conv.out_channels, vpOut32); - m_weight_dtype = savedDtype; - addCastOp(block, vpOut32, v1_pooled, "fp16", {-1, vh.v1_conv.out_channels * 3}); - } else { - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); - } + addGlobalPoolingFp32(block, v1, mask, vh.v1_conv.out_channels, v1_pooled, /*valueVariant=*/true); // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index ad67f150f..fe63b442f 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -118,6 +118,15 @@ class MILBuilder { const std::string& dtype, const std::vector& dims); + // Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool, + // cast the pooled features back to FP16). valueVariant selects the value-head pooling variant. + void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, From d052d2a1be587360e89f9516673fcb6cde4b707c Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:44:59 +0800 Subject: [PATCH 14/18] Fix CoreML/ANE convnet regression: scope FP32 tiers to transformers The width-keyed precision tiers (commit 3839e529) forced FP32 ops off the FP16-only ANE on plain production convnets, not just transformers. b18c384nbt ran ~2.6x slower on the ANE path (160 vs 416 visits/s) with no accuracy benefit. The dominant cost is the per-block global-pooling FP32 (non-spatial), which breaks the ANE pipeline once per gpool-residual block; conv-FP32 is secondary. Add a recursive blocksContainTransformer() helper and gate all three escalations (full-FP32, non-spatial-FP32, conv-FP32) on transformer-block presence. Convnets now run pure FP16 on the ANE (the long-standing pre-tier path); for transformer models the added "&& hasTransformer" is always true, so their emitted MIL is byte-identical and behavior is unchanged. Verified on the ANE FP16 path: b18c384nbt testgpuerror passes (winrate 99%=0.57%, max=0.87%) and recovers full throughput (424 visits/s); b28c512nbt passes (99%=0.41%); all 4 transformer test models x sizes 9/13/19 pass with numbers byte-identical to before; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 45 ++++++++++++++----- .../katagocoreml/src/builder/MILBuilder.hpp | 18 ++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index b4974aa39..09ab365ff 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -23,6 +23,25 @@ struct ScopedFp32 { ScopedFp32(const ScopedFp32&) = delete; ScopedFp32& operator=(const ScopedFp32&) = delete; }; + +// True if any block in this list is a transformer (attention/FFN), recursing into nested-bottleneck +// blocks (which can themselves contain transformer blocks). Used to scope the off-ANE FP32 +// escalations to transformer trunks only. +bool blocksContainTransformer(const std::vector& blocks) { + for (const auto& entry : blocks) { + if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND || + entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + return true; + } + if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { + const auto& nbt = std::get(*entry.block); + if (blocksContainTransformer(nbt.blocks)) { + return true; + } + } + } + return false; +} } // namespace MILBuilder::MILBuilder(const KataGoModelDesc& model, @@ -46,22 +65,28 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) , m_var_counter(0) { - // Precision tiers in FP16 mode (the ANE accumulates FP16 in FP16; FP32 ops run off the FP16-only - // ANE). NARROW transformer trunks are unreliable on the FP16 ANE: their policy/value metrics sit - // right on the testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial - // FP32 leaves a noisy FP16 spatial stream). So build narrow trunks FULLY in FP32 (off-ANE, but - // cheap since narrow models are small; correct because it equals the FP32 reference). Weights are - // stored FP32 via per-weight serialization. Wider trunks use partial FP32: non-spatial (matmuls + - // pooling) always FP32; convs FP32 only for very wide trunks (kept on the ANE for narrower ones). + // Precision in FP16 mode. The ANE accumulates FP16 in FP16, so any FP32 op runs OFF the FP16-only + // ANE (on CPU/GPU), breaking the ANE pipeline. These off-ANE FP32 escalations are applied ONLY to + // transformer trunks, whose attention blocks widen the activation range enough to overflow FP16 + // accumulation. Plain convnets stay PURE FP16 on the ANE -- the long-standing pre-tier path, verified + // to pass testgpuerror (b18c384nbt, b28c512nbt) and ~2.6x faster than forcing their per-block global + // pooling and convs to FP32 (measured: the per-block pooling round-trips, not the convs, dominate the + // slowdown). For transformers: + // - NARROW trunks (<256ch) build FULLY in FP32: their policy/value metrics sit right on the + // testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial FP32 leaves a + // noisy FP16 spatial stream). Off-ANE but cheap since narrow; equals the FP32 reference. Weights + // stored FP32 (per-weight serialization). + // - WIDER trunks use partial FP32: non-spatial (matmuls + pooling) always, convs only for >=320ch. const int trunkChannels = model.trunk.trunk_num_channels; - const bool full_fp32 = use_fp16 && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + const bool hasTransformer = blocksContainTransformer(model.trunk.blocks); + const bool full_fp32 = use_fp16 && hasTransformer && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; if (full_fp32) { m_use_fp16 = false; m_use_fp16_io = false; m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; } - m_nonspatial_fp32 = m_use_fp16; - m_conv_fp32 = m_use_fp16 && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; + m_nonspatial_fp32 = m_use_fp16 && hasTransformer; + m_conv_fp32 = m_use_fp16 && hasTransformer && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; } void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index fe63b442f..e38afb05e 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -43,14 +43,16 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; - // FP32 in FP16 mode follows KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), but - // FP32 ops run off the FP16-only ANE, so convs are channel-gated to only the wide trunks that - // need it. RMSNorm reductions: always FP32 (cheap, needed by all). Non-spatial matmuls+pooling: - // always FP32 (every width needs it at some board size). Convs: FP32 only for wide trunks. - static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // convs run FP32 at/above this width - static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // trunks below this build fully FP32 - bool m_nonspatial_fp32 = false; // = m_use_fp16 (matmuls + global pooling) - bool m_conv_fp32 = false; // = m_use_fp16 && trunk_channels >= CONV_FP32_MIN_... + // FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer + // trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation). + // Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass + // testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32. + // For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls + + // pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; From 145902423e700bbb3348bba45346ce9ade711d46 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 19:08:35 +0800 Subject: [PATCH 15/18] Fix RoPE cos/sin table buffer leak in Metal backend makeRopeTables allocated cosBuf/sinBuf with UnsafeMutablePointer.allocate and handed them to Data(floatsNoCopy:), which uses deallocator: .none, so the buffers were never freed -- a leak on every graph build (per attention block, per board size). Unlike the other floatsNoCopy callers (weights/gamma/beta), which point at C++-descriptor memory that lives for the model's lifetime, these tables have no persistent owner. Switch to managed [Float32] arrays and copy into the Data via Data(buffer:) so MPSGraph owns the bytes -- avoids both the leak and a use-after-free that a naive deallocate() on the no-copy path would cause. Output-neutral: testgpuerror on the GQA + learnable-RoPE model (b7c96h6kv3qk32v16tflrs, board 19) vs Eigen FP32 reference matches to 0.00028% max winrate error over 2247 positions. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index dc4d22d53..96667f3ed 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -1438,8 +1438,11 @@ struct TransformerAttentionBlock { nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { let count = nHeads * seq * numPairs - let cosBuf = UnsafeMutablePointer.allocate(capacity: count) - let sinBuf = UnsafeMutablePointer.allocate(capacity: count) + // Managed arrays (freed on return). Unlike the weight constants elsewhere, which point at + // C++-owned descriptor memory and so use floatsNoCopy, these tables have no persistent owner; + // we copy them into the Data below so MPSGraph owns the bytes (avoids a leak / use-after-free). + var cosBuf = [Float32](repeating: 0, count: count) + var sinBuf = [Float32](repeating: 0, count: count) let numPairsPerDim = numPairs / 2 let dimHalf = qHeadDim / 2 for h in 0.. Date: Tue, 2 Jun 2026 21:54:45 +0800 Subject: [PATCH 16/18] Guard non-SwiGLU transformer FFN in Metal backend The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path (SiLU(linear1) * gate). A non-SwiGLU model carries no gate weights, so building the Swift descriptor from the empty linearGate would crash obscurely (or silently misbehave). Eigen (eigenbackend.cpp) and CoreML (katagocoreml MILBuilder) both throw a clear "non-SwiGLU transformer FFN not supported" error in this case; the Metal GPU path had no such guard. Add the matching StringError at the FFN descriptor conversion so all three backends fail loudly and consistently. No behavior change for any current model (all use useSwiGLU=true): the guard sits on an untaken path. Verified the SwiGLU model b10c384h6nbttflrs still passes testgpuerror on both GPU (0.00005% winrate) and ANE (unchanged from baseline); runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metalbackend.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 75a4e9a93..10ee50b62 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -256,6 +256,11 @@ SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const Trans /// Convert a transformer FFN block description from C++ to Swift SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + // The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path + // (SiLU(linear1) * gate); a non-SwiGLU model has no gate weights, so guard here as Eigen and CoreML + // do (eigenbackend.cpp / katagocoreml MILBuilder) instead of crashing on the empty gate descriptor. + if(!desc->useSwiGLU) + throw StringError(desc->name + ": non-SwiGLU transformer FFN not supported in Metal backend"); SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); From 39f82f6d2ef013e73ac56968c7bd8cd8e4b9576e Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:31:04 +0800 Subject: [PATCH 17/18] Use named constant for trunk norm kind in Metal backend The trunk-tip dispatch in metallayers.swift compared trunkNormKind against the literal 1, while the rest of the codebase uses the named constants from desc.h (TRUNK_NORM_KIND_STANDARD/_RMSNORM). Add matching Swift constants and use TRUNK_NORM_KIND_RMSNORM at the comparison site. Pure literal-to-named-constant rename; no behavior change. Verified both branches still pass testgpuerror at GPU-level accuracy: RMSNorm tip (b10c384h6nbttflrs) 0.00005% winrate, BatchNorm tip (b7c96h6kv3 GQA) 0.00003%; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index 96667f3ed..e1324df96 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -2014,6 +2014,10 @@ class SGFMetadataEncoder { // MARK: - Trunk +/// Trunk-tip normalization kind, mirroring desc.h TRUNK_NORM_KIND_* (the value is serialized in the model). +let TRUNK_NORM_KIND_STANDARD = 0 // BatchNorm or BiasMask (existing) +let TRUNK_NORM_KIND_RMSNORM = 1 // RMSNorm + /// A class that describes a trunk for a neural network public class SWTrunkDesc { let version: Int @@ -2184,9 +2188,8 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - // TRUNK_NORM_KIND_RMSNORM == 1: trunk tip uses RMSNorm with a fused activation. - // Otherwise (standard): BatchNorm followed by a separate activation. - if descriptor.trunkNormKind == 1 { + // RMSNorm trunk tip uses a fused activation; standard uses BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == TRUNK_NORM_KIND_RMSNORM { let trunkTipRMSNorm = TrunkRMSNormLayer( graph: graph, sourceTensor: blocks.resultTensor, From 8481a9411854618befefdc0f25ff15402e2d6c70 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Thu, 4 Jun 2026 12:04:44 +0800 Subject: [PATCH 18/18] Conform CoreML transformer derived consts to the owned-weight + FP32 contract The transformer attention builder emits four function-local std::vector tensors: RoPE cos/sin tables, the rotation matrix R, and per-head out-projection weight slices. After merging the transformer support onto the FloatView branch, these needed two fixes: 1. Dangling view. #1202 made WeightEntry::data a non-owning FloatView, so addConstOp registers a view whose backing buffer must outlive serialization. These locals were passed to addConstOp and would dangle once the build function returns (serialization runs afterwards). Route them through addOwnedConstOp so KataGoOps owns the buffer until serialization. (Under #1205's owning WeightEntry they were copied, so this only surfaces post-merge.) 2. dtype mismatch. emitConstOp declares each const's dtype as m_weight_dtype, but addOwnedConstOp / registerOwnedWeight stored at the global mode (is_fp32 hardcoded false). In an FP16 model these derived consts land in the attention / value-head FP32 sub-region (m_weight_dtype == FLOAT32), so they were declared FP32 but stored FP16. CoreML/ANE then rejects the model at load ("Metadata data type does not match requested type", BNNS error -14), which SIGABRT'd every FP16 ANE transformer. Thread is_fp32 through registerOwnedWeight and have addOwnedConstOp pass is_fp32 = (m_weight_dtype == FLOAT32), mirroring addConstOp so the stored dtype always matches the declared dtype. This also fixes the same latent mismatch for addLinearOp's transposed value-head weights. Verified with testgpuerror against fresh Eigen FP32 references: b7c96h3tfrs and b7c96h6gqa, which previously SIGABRT'd on the FP16 ANE path, now load and match to <0.0005% winrate; convnet ANE output is byte-identical and the Metal GPU path is unchanged. katago runtests and runnnlayertests also pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 21 +++++++++++++------ .../katagocoreml/src/builder/Operations.cpp | 4 +++- .../katagocoreml/src/builder/Operations.hpp | 6 ++++-- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index ba3db1a19..f86181d87 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -281,8 +281,12 @@ void MILBuilder::addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, std::vector&& data, const std::vector& shape) { - // Register derived weight; KataGoOps takes ownership of the buffer - m_ops.registerOwnedWeight(name, std::move(data), shape); + // Register derived/owned weight. Mirror addConstOp's per-weight FP32 marking: emitConstOp + // declares this const's dtype as m_weight_dtype, so the stored bytes must follow the same flag + // or BNNS rejects the model ("Metadata data type does not match requested type") when a derived + // const lands in an FP32 sub-region of an FP16 model. + const bool is_fp32 = (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); + m_ops.registerOwnedWeight(name, std::move(data), shape, is_fp32); emitConstOp(block, name, shape); } @@ -2230,10 +2234,13 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI std::string cosName = prefix + "_" + tag + "_cos"; std::string sinName = prefix + "_" + tag + "_sin"; std::string rName = prefix + "_" + tag + "_R"; - addConstOp(block, cosName, cosFull, {1, nh, seq, qHeadDim}); - addConstOp(block, sinName, sinFull, {1, nh, seq, qHeadDim}); + // cosFull/sinFull/R are locals computed here, so register them as OWNED consts: the + // WeightEntry holds a non-owning FloatView and serialization runs after this lambda + // returns, so a non-owning addConstOp would dangle. + addOwnedConstOp(block, cosName, std::move(cosFull), {1, nh, seq, qHeadDim}); + addOwnedConstOp(block, sinName, std::move(sinFull), {1, nh, seq, qHeadDim}); // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. - addConstOp(block, rName, R, {1, 1, qHeadDim, qHeadDim}); + addOwnedConstOp(block, rName, std::move(R), {1, 1, qHeadDim, qHeadDim}); std::string rotated = genVarName(prefix + "_" + tag + "_rot"); matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); std::string xc = genVarName(prefix + "_" + tag + "_xc"); @@ -2363,7 +2370,9 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI for (int d = 0; d < vHeadDim; d++) for (int c = 0; c < outC; c++) whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; - addConstOp(block, wh, whData, {vHeadDim, outC}); + // whData is a per-head local slice; register OWNED so its FloatView stays valid until + // serialization (a non-owning addConstOp would dangle after this loop iteration). + addOwnedConstOp(block, wh, std::move(whData), {vHeadDim, outC}); std::string contrib = genVarName(prefix + "_contrib"); matmul(aoh2d, wh, contrib, {-1, outC}, false, false); if (h == 0) { diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index 5de42d09c..e86364943 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -28,7 +28,8 @@ std::string KataGoOps::registerWeight(const std::string& name, std::string KataGoOps::registerOwnedWeight(const std::string& name, std::vector&& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { m_owned.push_back(std::move(data)); const std::vector& stored = m_owned.back(); WeightEntry entry; @@ -36,6 +37,7 @@ std::string KataGoOps::registerOwnedWeight(const std::string& name, entry.data = FloatView{stored.data(), stored.size()}; entry.shape = shape; entry.blob_offset = 0; + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 1fb0d92a8..385648d19 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -82,10 +82,12 @@ class KataGoOps { const std::vector& shape) = delete; /// Register a derived/temporary weight; KataGoOps takes ownership so the - /// view stays valid through serialization. + /// view stays valid through serialization. is_fp32 marks it for FP32 storage + /// (mirrors registerWeight) so the stored dtype matches the declared const dtype. std::string registerOwnedWeight(const std::string& name, std::vector&& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); /// Get all registered weights (mutable; serialization sets blob_offset) std::vector& getWeightsMutable() { return m_weights; }