diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 99b81750a..b07dde8fc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -156,6 +156,66 @@ jobs: name: katago-macos-metal path: cpp/katago + build-macos-mlx: + # macos-latest is Apple Silicon (arm64), which the MLX backend requires. + runs-on: macos-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + brew install ninja zlib libzip protobuf abseil mlx + + # The CMake build (build.ninja, CMakeCache.txt) bakes in version-pinned + # Homebrew Cellar paths for protobuf/abseil/mlx (e.g. + # -L/opt/homebrew/Cellar/mlx/0.31.2/lib). When Homebrew bumps those + # the cached paths go stale and the link fails. + # Capture the installed versions into the cache key so + # a formula bump invalidates the cache and forces a fresh configure. + - name: Capture dependency versions for cache key + id: dep-versions + run: | + echo "versions=$(brew list --versions protobuf abseil mlx | tr '\n' '-')" >> "$GITHUB_OUTPUT" + + - name: Cache CMake build + uses: actions/cache@v4 + with: + path: | + cpp/CMakeCache.txt + cpp/CMakeFiles + cpp/build.ninja + cpp/.ninja_deps + cpp/.ninja_log + key: ${{ runner.os }}-cmake-mlx-${{ steps.dep-versions.outputs.versions }}-${{ hashFiles('**/CMakeLists.txt') }} + restore-keys: | + ${{ runner.os }}-cmake-mlx-${{ steps.dep-versions.outputs.versions }}- + + - name: Configure CMake + working-directory: cpp + run: | + cmake . -G Ninja -DUSE_BACKEND=MLX -DCMAKE_BUILD_TYPE=Release + + - name: Build + working-directory: cpp + run: | + ninja + + - name: Run tests + working-directory: cpp + run: | + ./katago runtests + + - name: Upload artifact + if: github.event_name == 'push' && github.ref == 'refs/heads/master' + uses: actions/upload-artifact@v4 + with: + name: katago-macos-mlx + path: cpp/katago + build-windows: runs-on: windows-latest permissions: @@ -185,7 +245,7 @@ jobs: - name: Configure CMake working-directory: cpp run: | - cmake . -G "Visual Studio 17 2022" -A x64 ` + cmake . -A x64 ` -DUSE_BACKEND=OPENCL ` -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" diff --git a/Compiling.md b/Compiling.md index abe7de36f..feaacaffb 100644 --- a/Compiling.md +++ b/Compiling.md @@ -131,8 +131,9 @@ As also mentioned in the instructions below but repeated here for visibility, if * [Homebrew](https://brew.sh): `/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"` * CMake with a minimum version of 3.18.2: `brew install cmake`. * AppleClang and Swift compilers: `xcode-select --install`. - * If using the Metal backend, [Ninja](https://ninja-build.org): `brew install ninja` - * If using the Metal backend, protobuf and abseil: `brew install protobuf abseil` + * If using the Metal or MLX backend, [Ninja](https://ninja-build.org): `brew install ninja` + * If using the Metal or MLX backend, protobuf and abseil: `brew install protobuf abseil` + * If using the MLX backend (Apple Silicon only): `brew install mlx` (≥0.18). Requires CMake ≥3.27. KataGo finds MLX via CMake's default search (Homebrew installs it at `/opt/homebrew/share/cmake/MLX/`); override with `-DMLX_ROOT=/path/to/mlx/cmake` if needed. * libzip: `brew install libzip`. * If you want to do self-play training and research, probably Google perftools `brew install gperftools` for TCMalloc or some other better malloc implementation. For unknown reasons, the allocation pattern in self-play with large numbers of threads and parallel games causes a lot of memory fragmentation under glibc malloc that will eventually run your machine out of memory, but better mallocs handle it fine. * If compiling to contribute to public distributed training runs, OpenSSL is required (`brew install openssl`). @@ -140,14 +141,14 @@ As also mentioned in the instructions below but repeated here for visibility, if * `git clone https://github.com/lightvector/KataGo.git` * Compile using CMake and make in the cpp directory: * `cd KataGo/cpp` - * `cmake . -G Ninja -DUSE_BACKEND=METAL` or `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=EIGEN` depending on which backend you want. + * `cmake . -G Ninja -DUSE_BACKEND=METAL` or `cmake . -G Ninja -DUSE_BACKEND=MLX` or `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=EIGEN` depending on which backend you want. The METAL and MLX backends use Swift/C++ interop, which requires the Ninja generator (`-G Ninja`); the other backends use the default Make generator. * Specify also `-DUSE_TCMALLOC=1` if using TCMalloc. * Compiling will also call git commands to embed the git hash into the compiled executable, specify also `-DNO_GIT_REVISION=1` to disable it if this is causing issues for you. * Specify `-DUSE_AVX2=1` to also compile Eigen with AVX2 and FMA support, which will make it incompatible with old CPUs but much faster. Intel-based Macs with new processors support AVX2, but Apple Silicon Macs do not support AVX2 natively. (If you want to go further, you can also add `-DCMAKE_CXX_FLAGS='-march=native'` which will specialize to precisely your machine's CPU, but the exe might not run on other machines at all). * Specify `-DBUILD_DISTRIBUTED=1` to compile with support for contributing data to public distributed training runs. * If building distributed, you will also need to build with Git revision support, including building within a clone of the repo, as opposed to merely an unzipped copy of its source. * Only builds from specific tagged versions or branches can contribute, in particular, instead of the `master` branch, use either the latest [release](https://github.com/lightvector/KataGo/releases) tag or the tip of the `stable` branch. To minimize the chance of any data incompatibilities or bugs, please do NOT attempt to contribute with custom changes or circumvent these limitations. - * `ninja` for Metal backend, or `make` for other backends. + * `ninja` for the Metal and MLX backends, or `make` for other backends. * Done! You should now have a compiled `katago` executable in your working directory. * Pre-trained neural nets are available at [the main training website](https://katagotraining.org/). * You will probably want to edit `configs/gtp_example.cfg` (see "Tuning for Performance" above). diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5c28fa94e..e1aa865de 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,34 @@ cmake_minimum_required(VERSION 3.18.2) -if(USE_BACKEND STREQUAL "METAL") + +# Pre-project MLX setup. KataGo's MLX path enforces CMake 3.27 via the guard +# below (MLX itself requires only 3.25 - 3.27 is chosen to match +# cmake_policy(VERSION 3.27)); the global cmake_minimum_required stays at +# 3.18.2 so non-MLX backends keep building on older CMake. +# +# The OSX deployment target is deliberately NOT pinned here. KataGo links +# Homebrew's prebuilt libmlx.dylib, whose minos reflects the macOS it was +# bottled on - that dylib, not this build, sets the real minimum macOS. +# Pinning a lower value only stamps a misleading minos on the executable and +# triggers a "linking with dylib built for newer version" linker warning; +# letting CMake default the target to the build host keeps minos honest. +# (A post-project set(CMAKE_OSX_DEPLOYMENT_TARGET) is a silent no-op for this +# Swift project - the target is fixed during project()/enable_language - so it +# is not pinned in the backend branches below either.) +# Normalize the backend name to uppercase BEFORE project(), so the +# case-insensitive behavior of the post-project() string(TOUPPER ...) below +# also applies to the pre-project() MLX version guard and the Swift language +# selection. Without this, a lowercase -DUSE_BACKEND=mlx would silently skip +# the 3.27 guard and the Swift enablement, then still build Swift sources later. +string(TOUPPER "${USE_BACKEND}" USE_BACKEND_NORMALIZED) + +if(USE_BACKEND_NORMALIZED STREQUAL "MLX") + if(CMAKE_VERSION VERSION_LESS 3.27) + message(FATAL_ERROR "KataGo's USE_BACKEND=MLX path requires CMake 3.27 or newer. You have ${CMAKE_VERSION}. Install via: brew install cmake") + endif() + cmake_policy(VERSION 3.27) +endif() + +if(USE_BACKEND_NORMALIZED STREQUAL "METAL" OR USE_BACKEND_NORMALIZED STREQUAL "MLX") project(katago LANGUAGES CXX Swift) else() project(katago) @@ -44,7 +73,7 @@ endif() set(BUILD_DISTRIBUTED 0 CACHE BOOL "Build with http support for contributing to distributed training") set(USE_BACKEND CACHE STRING "Neural net backend") string(TOUPPER "${USE_BACKEND}" USE_BACKEND) -set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN METAL) +set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN MLX METAL) set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc") set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe") @@ -134,7 +163,6 @@ elseif(USE_BACKEND STREQUAL "METAL") include(InitializeSwift) include(AddSwift) - set(CMAKE_OSX_DEPLOYMENT_TARGET 13.0) set(NEURALNET_BACKEND_SOURCES neuralnet/metalbackend.cpp ) @@ -164,8 +192,72 @@ elseif(USE_BACKEND STREQUAL "EIGEN") set(NEURALNET_BACKEND_SOURCES neuralnet/eigenbackend.cpp ) +elseif(USE_BACKEND STREQUAL "MLX") + message(STATUS "-DUSE_BACKEND=MLX, using MLX backend (with CoreML/ANE MUX) for Apple Silicon.") + + if(NOT APPLE) + message(FATAL_ERROR "USE_BACKEND=MLX is only supported on macOS. Detected: ${CMAKE_SYSTEM_NAME}") + endif() + if(CMAKE_OSX_ARCHITECTURES) + if(NOT CMAKE_OSX_ARCHITECTURES STREQUAL "arm64") + message(FATAL_ERROR "USE_BACKEND=MLX requires arm64. Got: ${CMAKE_OSX_ARCHITECTURES}") + endif() + elseif(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + message(FATAL_ERROR "USE_BACKEND=MLX requires Apple Silicon (arm64). Detected: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + + # CoreML/ANE MUX prerequisites — same constraints the METAL branch above + # enforces (same wording for grep parity). + if(NOT "${CMAKE_GENERATOR}" STREQUAL "Ninja") + message(FATAL_ERROR "Bidirectional C++ Interop requires Ninja generator. Have ${CMAKE_GENERATOR}") + endif() + if("${CMAKE_Swift_COMPILER_VERSION}" VERSION_LESS 5.9) + message(FATAL_ERROR "Bidirectional C++ Interop requires Swift 5.9 or greater. Have ${CMAKE_Swift_COMPILER_VERSION}") + endif() + if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") + message(FATAL_ERROR "Project requires building with AppleClang. Have ${CMAKE_CXX_COMPILER_ID}") + endif() + + # katagocoreml provides the native CoreML conversion C++ library used by the ANE mux. + add_subdirectory(external/katagocoreml) + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/external/macos/cmake/modules") + + if (NOT CMAKE_OSX_SYSROOT) + execute_process(COMMAND xcrun --show-sdk-path OUTPUT_VARIABLE CMAKE_OSX_SYSROOT OUTPUT_STRIP_TRAILING_WHITESPACE) + endif() + + include(InitializeSwift) + include(AddSwift) + + set(MLX_MIN_VERSION "0.18") + set(MLX_ROOT "" CACHE PATH "Optional path to MLX's CMake package; leave empty to use CMake's default search (e.g. Homebrew's /opt/homebrew/share/cmake/MLX/)") + + # Homebrew installs MLX's CMake config to /opt/homebrew/share/cmake/MLX/, which is + # on CMake's default search path. MLX_ROOT, when set, is added as an extra hint. + find_package(MLX ${MLX_MIN_VERSION} CONFIG REQUIRED HINTS "${MLX_ROOT}") + message(STATUS "Found MLX ${MLX_VERSION} at ${MLX_LIBRARY}") + + set(NEURALNET_BACKEND_SOURCES + neuralnet/mlxbackend.cpp + neuralnet/mlxwinotuner.cpp + neuralnet/mlxtests.cpp + ) + + # Build the KataGoSwift static library. Same lines as the METAL branch above, + # kept inline to leave the Metal branch untouched. The library exposes + # CoreMLComputeHandle to C++ via the generated KataGoSwift-swift.h. + add_library(KataGoSwift STATIC + neuralnet/metalbackend.swift + neuralnet/metallayers.swift) + _swift_generate_cxx_header( + KataGoSwift + "${CMAKE_CURRENT_BINARY_DIR}/include/KataGoSwift/KataGoSwift-swift.h") + target_include_directories(KataGoSwift PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/include") + set_target_properties(KataGoSwift PROPERTIES Swift_MODULE_NAME "KataGoSwift") + target_compile_options(KataGoSwift PUBLIC + "$<$:-cxx-interoperability-mode=default>") elseif(USE_BACKEND STREQUAL "") - message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN to compile with the respective backend.${ColorReset}") + message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN or -DUSE_BACKEND=MLX or -DUSE_BACKEND=METAL to compile with the respective backend.${ColorReset}") set(NEURALNET_BACKEND_SOURCES neuralnet/dummybackend.cpp) else() message(FATAL_ERROR "Unrecognized backend: " ${USE_BACKEND}) @@ -531,6 +623,10 @@ elseif(USE_BACKEND STREQUAL "EIGEN") message(STATUS "Found Eigen3 at ${EIGEN3_INCLUDE_DIRS}") endif() endif() +elseif(USE_BACKEND STREQUAL "MLX") + target_compile_definitions(katago PRIVATE USE_MLX_BACKEND) + target_link_libraries(katago mlx KataGoSwift katagocoreml + ${KATAGOCOREML_DEP_LDFLAGS}) endif() if(USE_BIGGER_BOARDS_EXPENSIVE) diff --git a/cpp/command/benchmark.cpp b/cpp/command/benchmark.cpp index 45d5312f2..d6b410450 100644 --- a/cpp/command/benchmark.cpp +++ b/cpp/command/benchmark.cpp @@ -266,6 +266,9 @@ int MainCmds::benchmark(const vector& args) { #endif #ifdef USE_EIGEN_BACKEND cout << "You are currently using the Eigen (CPU) version of KataGo. Due to having no GPU, it may be slow." << endl; +#endif +#ifdef USE_MLX_BACKEND + cout << "Your GTP config is currently set to mlxUseFP16 = " << nnEval->getUsingFP16Mode().toString() << endl; #endif cout << endl; cout << "Your GTP config is currently set to use numSearchThreads = " << params.numThreads << endl; diff --git a/cpp/command/tune.cpp b/cpp/command/tune.cpp index f6a089152..a984d43c1 100644 --- a/cpp/command/tune.cpp +++ b/cpp/command/tune.cpp @@ -11,14 +11,16 @@ #include "../neuralnet/opencltuner.h" #endif +#ifdef USE_MLX_BACKEND +#include "../program/setup.h" +#include "../neuralnet/desc.h" +#include "../neuralnet/mlxwinotuner.h" +#endif + using namespace std; -int MainCmds::tuner(const vector& args) { -#ifndef USE_OPENCL_BACKEND - cout << "Currently this command only does anything for the OpenCL version of KataGo" << endl; - (void)args; - return 0; -#else +#if defined(USE_OPENCL_BACKEND) +static int runOpenCLTuner(const vector& args) { ConfigParser cfg; string modelFile; @@ -226,6 +228,142 @@ int MainCmds::tuner(const vector& args) { } return 0; +} +#elif defined(USE_MLX_BACKEND) +static int runMLXTuner(const vector& args) { + + // MLX (Apple GPU) tuner: searches the Winograd input/output transform grids + // and writes the winning parameters to the same cache the backend reads at + // model load. This is the deliberate "command tune" path; the auto-tune that + // runs during normal model load stays coarse/fast. The OpenCL-only FP16 + // sub-knobs (storage/compute/tensorcores) have no MLX analog and are omitted. + ConfigParser cfg; + string modelFile; + string outputFileFromArg; + int nnXLen; + int nnYLen; + int batchSize; + string testFP16Str; + enabled_t testFP16Mode; + bool full; + try { + KataGoCommandLine cmd("Perform Winograd transform tuning for the MLX (Apple GPU) backend."); + cmd.addConfigFileArg(KataGoCommandLine::defaultGtpConfigFileName(),"gtp_example.cfg"); + cmd.addModelFileArg(); + + TCLAP::ValueArg outputFileArg("","output","Filename to output tuning configuration to (default: shared MLX tuner cache)",false,string(),"FILE"); + TCLAP::ValueArg nnXLenArg("","xsize","Width of board to tune for",false,19,"INT"); + TCLAP::ValueArg nnYLenArg("","ysize","Height of board to tune for",false,19,"INT"); + TCLAP::ValueArg batchSizeArg("","batchsize","Batch size to tune for",false,8,"INT"); + TCLAP::ValueArg testFP16Arg("","testFP16","Tune for FP16? true|false|auto (default auto = engine default, FP16)",false,"auto","BOOL_OR_AUTO"); + TCLAP::SwitchArg fullArg("","full","Sweep the wide candidate grid instead of the default coarse one"); + + cmd.setShortUsageArgLimit(); + cmd.addOverrideConfigArg(); + + cmd.add(outputFileArg); + cmd.add(nnXLenArg); + cmd.add(nnYLenArg); + cmd.add(batchSizeArg); + cmd.add(testFP16Arg); + cmd.add(fullArg); + cmd.parseArgs(args); + + modelFile = cmd.getModelFile(); + outputFileFromArg = outputFileArg.getValue(); + nnXLen = nnXLenArg.getValue(); + nnYLen = nnYLenArg.getValue(); + batchSize = batchSizeArg.getValue(); + testFP16Str = testFP16Arg.getValue(); + full = fullArg.getValue(); + + if(!enabled_t::tryParse(testFP16Str,testFP16Mode)) { + cerr << "Error: Could not parse -testFP16 as bool or auto: " << testFP16Str << endl; + return 1; + } + + cmd.getConfigAllowEmpty(cfg); + } + catch (TCLAP::ArgException &e) { + cerr << "Error: " << e.error() << " for argument " << e.argId() << endl; + return 1; + } + + // The MLX GPU path runs FP16 unless explicitly disabled (useFP16Mode != False + // in mlxbackend.cpp), so 'auto' tunes for FP16 - the precision the engine will + // actually use, and the precision the cache filename is keyed on. + const bool useFP16 = (testFP16Mode != enabled_t::False); + + string homeDataDirOverride = Setup::loadHomeDataDirOverride(cfg); + + const bool logToStdoutDefault = true; + Logger logger(&cfg, logToStdoutDefault); + + logger.write("Loading model..."); + ModelDesc modelDesc; + string expectedSha256 = ""; + ModelDesc::loadFromFileMaybeGZipped(modelFile, modelDesc, expectedSha256); + + // Same shape diagnostic the backend logs at load, so the tuned cache can be + // correlated with the model's 3x3 conv shape mix. + logger.write(MLXWinogradTuner::formatConv3x3Distribution(modelDesc)); + + MLXWinogradTuner::ModelInfoForTuning modelInfo; + modelInfo.trunkNumChannels = modelDesc.trunk.trunkNumChannels; + modelInfo.modelVersion = modelDesc.modelVersion; + { + auto histograms = MLXWinogradTuner::buildConv3x3Histograms(modelDesc); + modelInfo.conv3x3InputHistogram = std::move(histograms.first); + modelInfo.conv3x3OutputHistogram = std::move(histograms.second); + } + + // Chip-specific cache key, shared with the backend's model-load path so the + // command writes exactly the file the backend reads; part of the filename key. + const string gpuName = MLXWinogradTuner::detectGpuName(); + + string outputFile; + if(outputFileFromArg == "") { + string dir = MLXWinogradTuner::defaultDirectory(true,homeDataDirOverride); + outputFile = dir + "/" + MLXWinogradTuner::defaultFileName( + gpuName, nnXLen, nnYLen, modelInfo.trunkNumChannels, modelInfo.modelVersion, useFP16, full); + } + else { + outputFile = outputFileFromArg; + } + + logger.write(string("MLX Winograd tuner starting (") + (full ? "full" : "coarse") + + " sweep, " + (useFP16 ? "FP16" : "FP32") + ", batch " + Global::intToString(batchSize) + ")..."); + + // reTune=true: a command tune always re-runs the search and overwrites the + // cache, rather than short-circuiting on an existing file. + MLXWinogradTuner::loadOrAutoTune( + outputFile, + homeDataDirOverride, + gpuName, + nnXLen, + nnYLen, + batchSize, + modelInfo, + &logger, + /*full=*/full, + /*reTune=*/true, + /*useFP16=*/useFP16 + ); + + cout << "Done, results saved to " << outputFile << endl; + + return 0; +} +#endif +int MainCmds::tuner(const vector& args) { +#if defined(USE_OPENCL_BACKEND) + return runOpenCLTuner(args); +#elif defined(USE_MLX_BACKEND) + return runMLXTuner(args); +#else + cout << "Currently this command only does anything for the OpenCL and MLX versions of KataGo" << endl; + (void)args; + return 0; #endif } diff --git a/cpp/configs/analysis_example.cfg b/cpp/configs/analysis_example.cfg index 728014b21..9df0cdea3 100644 --- a/cpp/configs/analysis_example.cfg +++ b/cpp/configs/analysis_example.cfg @@ -298,6 +298,47 @@ nnRandomize = true # It defaults to min(numAnalysisThreads * numSearchThreadsPerAnalysisThread, numCPUCores). # numEigenThreadsPerModel = X +# ------------------------------ +# MLX-specific settings +# ------------------------------ +# These only apply when using the MLX backend (Apple Silicon). + +# MLX backend dispatch is configured via numNNServerThreadsPerModel and mlxDeviceToUseThread. +# Device index values (same convention as the Metal backend): +# 0 = GPU only (MLX) - default +# 100 = ANE only (CoreML, runs on CPU + Apple Neural Engine) +# Any other value is rejected at startup. The backend-agnostic key +# `deviceToUseThread` is also accepted. +# +# Mux mode: pipeline GPU and ANE server threads to overlap their forward +# passes. Set nnMaxBatchSize to roughly half of numSearchThreads. +# +# Example: mux mode (2x GPU + 2x ANE) +# numNNServerThreadsPerModel = 4 +# mlxDeviceToUseThread0 = 0 +# mlxDeviceToUseThread1 = 0 +# mlxDeviceToUseThread2 = 100 +# mlxDeviceToUseThread3 = 100 +# +# Example: GPU-only mode (default) +# numNNServerThreadsPerModel = 1 +# mlxDeviceToUseThread0 = 0 +# +# Example: ANE-only mode (CoreML on CPU+ANE) +# numNNServerThreadsPerModel = 1 +# mlxDeviceToUseThread0 = 100 +# +# Default (no config): 1 server thread, GPU-only mode. + +# Whether to use FP16 (half precision) for neural net evaluation on MLX. +# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path. +# The ANE is FP16-only: on an ANE thread (gpuIdx = 100) with mlxUseFP16 = false, +# CoreML falls back to CPU FP32 - correct but much slower than the GPU path. +# Set `false` only for bit-exact FP32 reproducibility. +# +# Default: auto (resolves to fp16 on MLX). +# mlxUseFP16 = auto + # Misc Behavior -------------------- diff --git a/cpp/configs/contribute_example.cfg b/cpp/configs/contribute_example.cfg index 6ca039f11..0839560d4 100644 --- a/cpp/configs/contribute_example.cfg +++ b/cpp/configs/contribute_example.cfg @@ -139,3 +139,31 @@ watchOngoingGameInFileName = watchgame.txt # This is the number of CPU threads for evaluating the neural net on the Eigen backend. # It defaults to numSearchThreads. # numEigenThreadsPerModel = X + +# ------------------------------ +# MLX-specific settings +# ------------------------------ +# These only apply when using the MLX backend (Apple Silicon). + +# Per-server-thread dispatch (same convention as the Metal backend): +# 0 = GPU via MLX (default) +# 100 = ANE via CoreML (CPU + Apple Neural Engine) +# Mix in one config to pipeline GPU and ANE work. The backend-agnostic key +# `deviceToUseThread` is also accepted. +# +# Example: mux mode (2x GPU + 2x ANE) - also set numNNServerThreadsPerModel = 4 above +# mlxDeviceToUseThread0 = 0 +# mlxDeviceToUseThread1 = 0 +# mlxDeviceToUseThread2 = 100 +# mlxDeviceToUseThread3 = 100 +# +# Example: ANE-only single instance +# mlxDeviceToUseThread0 = 100 + +# Whether to use FP16 (half precision) for neural net evaluation on MLX. +# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path. +# The ANE is FP16-only: on an ANE thread (gpuIdx = 100) with mlxUseFP16 = false, +# CoreML falls back to CPU FP32 - correct but much slower than the GPU path. +# +# Default: auto (resolves to fp16 on MLX). +# mlxUseFP16 = auto diff --git a/cpp/configs/gtp_example.cfg b/cpp/configs/gtp_example.cfg index 8a261c4c3..618b5913a 100644 --- a/cpp/configs/gtp_example.cfg +++ b/cpp/configs/gtp_example.cfg @@ -539,6 +539,52 @@ searchFactorWhenWinningThreshold = 0.95 # Default: numSearchThreads # numEigenThreadsPerModel = X +# ------------------------------ +# MLX-specific settings +# ------------------------------ +# These only apply when using the MLX backend (Apple Silicon). + +# MLX backend dispatch is configured via numNNServerThreadsPerModel and mlxDeviceToUseThread. +# Device index values (same convention as the Metal backend): +# 0 = GPU only (MLX) - default +# 100 = ANE only (CoreML, runs on CPU + Apple Neural Engine) +# Any other value is rejected at startup. The backend-agnostic key +# `deviceToUseThread` is also accepted if you prefer not to commit to a +# backend-specific prefix. +# +# Mux mode: pipeline GPU and ANE server threads to overlap their forward +# passes. Set nnMaxBatchSize to roughly half of numSearchThreads for best +# pipelining. +# +# Example: mux mode (2x GPU + 2x ANE) +# numNNServerThreadsPerModel = 4 +# mlxDeviceToUseThread0 = 0 +# mlxDeviceToUseThread1 = 0 +# mlxDeviceToUseThread2 = 100 +# mlxDeviceToUseThread3 = 100 +# +# Example: GPU-only mode (default) +# numNNServerThreadsPerModel = 1 +# mlxDeviceToUseThread0 = 0 +# +# Example: ANE-only mode (CoreML on CPU+ANE; ~3 search threads is the +# observed throughput sweet spot since a single CoreML call serializes +# per batch) +# numNNServerThreadsPerModel = 1 +# mlxDeviceToUseThread0 = 100 +# +# Default (no config): 1 server thread, GPU-only mode. + +# Whether to use FP16 (half precision) for neural net evaluation on MLX. +# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path. +# The ANE is FP16-only hardware: on an ANE thread (gpuIdx = 100) with +# mlxUseFP16 = false, CoreML falls back to CPU FP32 - correct but much +# slower than the GPU path. Set `false` only for bit-exact FP32 +# reproducibility. +# +# Default: auto (resolves to fp16 on MLX). +# mlxUseFP16 = auto + # =========================================================================== # Root move selection and biases # =========================================================================== diff --git a/cpp/configs/match_example.cfg b/cpp/configs/match_example.cfg index 7e5b4fc09..992b48303 100644 --- a/cpp/configs/match_example.cfg +++ b/cpp/configs/match_example.cfg @@ -197,6 +197,34 @@ numNNServerThreadsPerModel = 1 # It defaults to numSearchThreads. # numEigenThreadsPerModel = X +# ------------------------------ +# MLX-specific settings +# ------------------------------ +# These only apply when using the MLX backend (Apple Silicon). + +# Per-server-thread dispatch (same convention as the Metal backend): +# 0 = GPU via MLX (default) +# 100 = ANE via CoreML (CPU + Apple Neural Engine) +# Mix in one config to pipeline GPU and ANE work. The backend-agnostic key +# `deviceToUseThread` is also accepted. +# +# Example: mux mode (2x GPU + 2x ANE) - also set numNNServerThreadsPerModel = 4 above +# mlxDeviceToUseThread0 = 0 +# mlxDeviceToUseThread1 = 0 +# mlxDeviceToUseThread2 = 100 +# mlxDeviceToUseThread3 = 100 +# +# Example: ANE-only single instance +# mlxDeviceToUseThread0 = 100 + +# Whether to use FP16 (half precision) for neural net evaluation on MLX. +# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path. +# The ANE is FP16-only: on an ANE thread (gpuIdx = 100) with mlxUseFP16 = false, +# CoreML falls back to CPU FP32 - correct but much slower than the GPU path. +# +# Default: auto (resolves to fp16 on MLX). +# mlxUseFP16 = auto + # Root move selection and biases------------------------------------------------------------------------------ # Uncomment and edit any of the below values to change them from their default. diff --git a/cpp/external/katagocoreml/src/Converter.cpp b/cpp/external/katagocoreml/src/Converter.cpp index cb6ca80d9..52cb615bf 100644 --- a/cpp/external/katagocoreml/src/Converter.cpp +++ b/cpp/external/katagocoreml/src/Converter.cpp @@ -29,9 +29,12 @@ void KataGoConverter::convert(const std::string& input_path, throw std::invalid_argument("max_batch_size must be >= min_batch_size or <= 0 for unlimited"); } - // Parse KataGo model - KataGoParser parser(input_path); - KataGoModelDesc model = parser.parse(); + // Parse KataGo model (parser + its decompressed buffer freed at end of scope) + KataGoModelDesc model; + { + KataGoParser parser(input_path); + model = parser.parse(); + } // Determine if using FP16 precision bool use_fp16 = (options.compute_precision == "FLOAT16"); @@ -52,12 +55,17 @@ void KataGoConverter::convert(const std::string& input_path, options.use_fp16_io); auto program = builder.build(); - // Get weights from builder - auto weights = builder.getWeights(); - std::vector weights_copy(weights.begin(), weights.end()); + // Serialize directly from the builder's weight views (no copy). + std::vector& weights = builder.getWeightsMutable(); // Update options with model metadata for serialization ConversionOptions final_options = options; + // The builder may downgrade IO to FP32 for narrow transformers; reflect its + // effective decision so the serialized feature descriptors (and the spec- + // version bump below) match the MIL program's actual IO dtype. Without this, + // a narrow transformer with use_fp16_io=true emits FP32 IO tensors while the + // model spec still declares FP16 IO — a mismatch that fails to load. + final_options.use_fp16_io = builder.getUseFp16Io(); final_options.model_version = model.model_version; final_options.meta_encoder_version = model.meta_encoder_version; final_options.num_input_meta_channels = model.num_input_meta_channels; @@ -82,7 +90,7 @@ void KataGoConverter::convert(const std::string& input_path, // Serialize to .mlpackage CoreMLSerializer serializer(final_options.specification_version); - serializer.serialize(program.get(), weights_copy, output_path, final_options); + serializer.serialize(program.get(), weights, output_path, final_options); } ModelInfo KataGoConverter::getModelInfo(const std::string& input_path) { diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..bd0a1163b 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -4,12 +4,54 @@ #include "MILBuilder.hpp" #include "MILBlob/Fp16.hpp" #include +#include // Include generated protobuf headers #include "MIL.pb.h" namespace katagocoreml { +namespace { +// RAII: while active, force a dtype slot to FLOAT32 and restore it on scope exit. Used to emit a +// sub-region of ops in FP32 inside an otherwise-FP16 model. Pass active=false to make the guard a +// no-op (so callers can guard a conditional FP32 window without branching on construction). restore() +// ends the window early and is idempotent, letting a caller drop back to FP16 before a trailing +// cast-down while still getting an exception-safe restore if an op emission throws in between. +struct ScopedFp32 { + CoreML::Specification::MILSpec::DataType* slot; + CoreML::Specification::MILSpec::DataType saved; + explicit ScopedFp32(CoreML::Specification::MILSpec::DataType& s, bool active = true) + : slot(active ? &s : nullptr), saved(s) { + if (slot) *slot = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + void restore() { + if (slot) { *slot = saved; slot = nullptr; } + } + ~ScopedFp32() { restore(); } + ScopedFp32(const ScopedFp32&) = delete; + ScopedFp32& operator=(const ScopedFp32&) = delete; +}; + +// True if any block in this list is a transformer (attention/FFN), recursing into nested-bottleneck +// blocks (which can themselves contain transformer blocks). Used to scope the off-ANE FP32 +// escalations to transformer trunks only. +bool blocksContainTransformer(const std::vector& blocks) { + for (const auto& entry : blocks) { + if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND || + entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + return true; + } + if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { + const auto& nbt = std::get(*entry.block); + if (blocksContainTransformer(nbt.blocks)) { + return true; + } + } + } + return false; +} +} // namespace + MILBuilder::MILBuilder(const KataGoModelDesc& model, int board_x_size, int board_y_size, @@ -30,7 +72,30 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, ? CoreML::Specification::MILSpec::DataType::FLOAT16 : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) - , m_var_counter(0) {} + , m_var_counter(0) { + // Precision in FP16 mode. The ANE accumulates FP16 in FP16, so any FP32 op runs OFF the FP16-only + // ANE (on CPU/GPU), breaking the ANE pipeline. These off-ANE FP32 escalations are applied ONLY to + // transformer trunks, whose attention blocks widen the activation range enough to overflow FP16 + // accumulation. Plain convnets stay PURE FP16 on the ANE -- the long-standing pre-tier path, verified + // to pass testgpuerror (b18c384nbt, b28c512nbt) and ~2.6x faster than forcing their per-block global + // pooling and convs to FP32 (measured: the per-block pooling round-trips, not the convs, dominate the + // slowdown). For transformers: + // - NARROW trunks (<256ch) build FULLY in FP32: their policy/value metrics sit right on the + // testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial FP32 leaves a + // noisy FP16 spatial stream). Off-ANE but cheap since narrow; equals the FP32 reference. Weights + // stored FP32 (per-weight serialization). + // - WIDER trunks use partial FP32: non-spatial (matmuls + pooling) always, convs only for >=320ch. + const int trunkChannels = model.trunk.trunk_num_channels; + const bool hasTransformer = blocksContainTransformer(model.trunk.blocks); + const bool full_fp32 = use_fp16 && hasTransformer && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + if (full_fp32) { + m_use_fp16 = false; + m_use_fp16_io = false; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + m_nonspatial_fp32 = m_use_fp16 && hasTransformer; + m_conv_fp32 = m_use_fp16 && hasTransformer && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; +} void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { auto* dim = tensor_type->add_dimensions(); @@ -212,9 +277,30 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage - m_ops.registerWeight(name, data, shape); + // Register weight for blob storage (non-owning view into the model). Mark FP32 storage when this + // const is declared FP32 (e.g. inside an FP32 sub-region of an otherwise-FP16 model) so storage + // matches the declared type. + m_ops.registerWeight(name, data, shape, + m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); + emitConstOp(block, name, shape); +} + +void MILBuilder::addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) { + // Register derived/owned weight. Mirror addConstOp's per-weight FP32 marking: emitConstOp + // declares this const's dtype as m_weight_dtype, so the stored bytes must follow the same flag + // or BNNS rejects the model ("Metadata data type does not match requested type") when a derived + // const lands in an FP32 sub-region of an FP16 model. + const bool is_fp32 = (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); + m_ops.registerOwnedWeight(name, std::move(data), shape, is_fp32); + emitConstOp(block, name, shape); +} +void MILBuilder::emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape) { // Add const operation auto* op = block->add_operations(); op->set_type("const"); @@ -328,7 +414,11 @@ void MILBuilder::addFloatScalarConstOp(CoreML::Specification::MILSpec::Block* bl val_type->set_datatype(m_weight_dtype); val_type->set_rank(0); - if (m_use_fp16) { + // Key the storage format off the DECLARED dtype (m_weight_dtype), not the global m_use_fp16: + // a temporarily-flipped FP32 sub-region (m_weight_dtype=FLOAT32 while m_use_fp16 stays true) + // must store FP32 floats, or CoreML rejects the model ("storage and type have different number + // of elements"). For all non-flipped calls m_weight_dtype tracks m_use_fp16, so this is a no-op. + if (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT16) { // For FP16, use bytes storage with FP16 representation MILBlob::Fp16 fp16_val = MILBlob::Fp16::FromFloat(value); std::string bytes_data(reinterpret_cast(&fp16_val.bytes), sizeof(fp16_val.bytes)); @@ -426,6 +516,68 @@ void MILBuilder::addCastOp(CoreML::Specification::MILSpec::Block* block, } } +std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims) { + std::string out = genVarName(input + "_cast"); + std::string dtName = out + "_dt"; + { + auto* op = block->add_operations(); + op->set_type("const"); + auto& na = (*op->mutable_attributes())["name"]; + na.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + na.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtName); + auto& va = (*op->mutable_attributes())["val"]; + va.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + va.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtype); + auto* o = op->add_outputs(); + o->set_name(dtName); + o->mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + } + auto* op = block->add_operations(); + op->set_type("cast"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(input); + (*op->mutable_inputs())["dtype"].add_arguments()->set_name(dtName); + auto* o = op->add_outputs(); + o->set_name(out); + auto* tt = o->mutable_type()->mutable_tensortype(); + tt->set_datatype(dtype == "fp32" ? CoreML::Specification::MILSpec::DataType::FLOAT32 + : CoreML::Specification::MILSpec::DataType::FLOAT16); + tt->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + if (d < 0) tt->add_dimensions()->mutable_unknown()->set_variadic(false); + else tt->add_dimensions()->mutable_constant()->set_size(d); + } + return out; +} + +void MILBuilder::addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant) { + auto pool = [&](const std::string& in, const std::string& msk, const std::string& out) { + if (valueVariant) addGlobalPoolingValueOps(block, in, msk, channels, out); + else addGlobalPoolingOps(block, in, msk, channels, out); + }; + // Non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum over H*W loses too much + // precision at larger board sizes). No addConstOp in the pooling, so flipping m_weight_dtype is + // safe. Cast input/mask up, pool in FP32, cast the [N, channels*3] features back to FP16. + if (!m_nonspatial_fp32) { + pool(input, mask, output); + return; + } + std::string in32 = castFixed(block, input, "fp32", {-1, channels, m_board_y_size, m_board_x_size}); + std::string mask32 = m_optimize_identity_mask + ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + std::string out32 = genVarName(output + "_f32"); + { ScopedFp32 g(m_weight_dtype); pool(in32, mask32, out32); } + addCastOp(block, out32, output, "fp16", {-1, channels * 3}); +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -566,6 +718,20 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, tt->add_dimensions()->mutable_constant()->set_size(4); } + // Channel-gated FP32 convs. The ANE accumulates FP16 convs in FP16, which loses too much + // precision for WIDE trunks and fails testgpuerror at large board sizes (validated: 384ch + // fails, <=256ch is fine FP16-on-ANE). For wide trunks (>= threshold) run convs in FP32 (weights + // cast up at runtime, stored fp16). FP32 convs can't run on the fp16-only ANE, so only the wide + // models that actually need it pay that off-ANE cost; narrow models keep convs on the ANE. + const bool convFp32 = m_conv_fp32; + std::string convX = input, convW = weight_name, convOut = output; + if (convFp32) { + convX = castFixed(block, input, "fp32", {-1, layer.in_channels, m_board_y_size, m_board_x_size}); + convW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + convOut = output + "_cf32"; + } + ScopedFp32 fp32Scope(m_weight_dtype, convFp32); + // Add conv operation referencing all const parameters auto* op = block->add_operations(); op->set_type("conv"); @@ -577,12 +743,12 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, inputs["pad"].add_arguments()->set_name(pad_name); inputs["pad_type"].add_arguments()->set_name(pad_type_name); inputs["strides"].add_arguments()->set_name(strides_name); - inputs["weight"].add_arguments()->set_name(weight_name); - inputs["x"].add_arguments()->set_name(input); + inputs["weight"].add_arguments()->set_name(convW); + inputs["x"].add_arguments()->set_name(convX); // Output with dimensions [batch, out_channels, height, width] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(convOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(4); @@ -590,6 +756,11 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + + if (convFp32) { + fp32Scope.restore(); + addCastOp(block, convOut, output, "fp16", {-1, layer.out_channels, m_board_y_size, m_board_x_size}); + } } // Helper: Set output tensor type with 4D shape [batch, C, H, W] @@ -732,6 +903,63 @@ void MILBuilder::addBatchNormActivationOps(CoreML::Specification::MILSpec::Block setTensorOutput4D(op, output, bn.num_channels, m_board_y_size, m_board_x_size); } else if (act.activation_type == ActivationType::Mish) { addMishOps(block, bn_output, output, 4, bn.num_channels); + } else if (act.activation_type == ActivationType::Silu) { + addSiluOps(block, bn_output, output, 4, bn.num_channels); + } +} + +void MILBuilder::addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels) { + // SiLU / Swish: x * sigmoid(x) + auto setOutputType = [this, rank, channels](CoreML::Specification::MILSpec::Operation* op, const std::string& name) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* out_type = out->mutable_type()->mutable_tensortype(); + out_type->set_datatype(m_weight_dtype); + out_type->set_rank(rank); + setBatchDimension(out_type); + out_type->add_dimensions()->mutable_constant()->set_size(channels); + if (rank == 4) { + out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); + out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + } + }; + + std::string sig = output + "_sigmoid"; + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + setOutputType(op, sig); + } + { + auto* op = block->add_operations(); + op->set_type("mul"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + inputs["y"].add_arguments()->set_name(sig); + setOutputType(op, output); + } +} + +void MILBuilder::setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* t = out->mutable_type()->mutable_tensortype(); + t->set_datatype(m_weight_dtype); + t->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + auto* dim = t->add_dimensions(); + if (d < 0) + dim->mutable_unknown()->set_variadic(false); + else + dim->mutable_constant()->set_size(d); } } @@ -887,23 +1115,37 @@ void MILBuilder::addMatMulOp(CoreML::Specification::MILSpec::Block* block, CoreML::Specification::MILSpec::DataType::BOOL); } + // Non-spatial matmul in FP32 (KataGo FP16 convention; weights cast up at runtime, stored fp16). + std::string mmIn = input, mmW = weight_name, mmOut = output; + if (m_nonspatial_fp32) { + mmIn = castFixed(block, input, "fp32", {-1, layer.in_channels}); + mmW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + mmOut = output + "_mmf32"; + } + ScopedFp32 fp32Scope(m_weight_dtype, m_nonspatial_fp32); + // Add matmul operation auto* op = block->add_operations(); op->set_type("matmul"); auto& inputs = *op->mutable_inputs(); inputs["transpose_x"].add_arguments()->set_name(transpose_x_name); inputs["transpose_y"].add_arguments()->set_name(transpose_y_name); - inputs["x"].add_arguments()->set_name(input); - inputs["y"].add_arguments()->set_name(weight_name); + inputs["x"].add_arguments()->set_name(mmIn); + inputs["y"].add_arguments()->set_name(mmW); // Output with 2D shape [batch, out_channels] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(mmOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(2); setBatchDimension(out_type); out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); + + if (m_nonspatial_fp32) { + fp32Scope.restore(); + addCastOp(block, mmOut, output, "fp16", {-1, layer.out_channels}); + } } void MILBuilder::addMatBiasOp(CoreML::Specification::MILSpec::Block* block, @@ -958,13 +1200,15 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, // Add transposed weight constant with shape [out_channels, in_channels] std::vector transposed_shape = {static_cast(out_ch), static_cast(in_ch)}; - addConstOp(block, weight_name, transposed_weights, transposed_shape); + addOwnedConstOp(block, weight_name, std::move(transposed_weights), transposed_shape); // Add bias constant std::vector bias_shape = {static_cast(bias.num_channels)}; addConstOp(block, bias_name, bias.weights, bias_shape); - // Add linear operation + // NOTE: the MIL `linear` op requires const weight/bias, so the runtime-cast-to-FP32 trick can't + // be applied here (unlike `matmul`). Value-head linear stays FP16; if a model ever needs it in + // FP32, rewrite as matmul+add (matmul accepts cast inputs). auto* op = block->add_operations(); op->set_type("linear"); auto& inputs = *op->mutable_inputs(); @@ -1637,6 +1881,635 @@ void MILBuilder::addGlobalPoolingValueOps(CoreML::Specification::MILSpec::Block* // Network Component Builders // ============================================================================ +// --------------------------------------------------------------------------- +// Transformer blocks (MIL). Layout is NCHW [B, C, H, W]; spatial positions +// (H*W, ordered y*W+x) are treated as the attention sequence. RoPE is applied +// via a fixed pair-rotation matmul plus host-precomputed cos/sin tables, which +// keeps every tensor rank <= 4 (ANE-friendly). +// --------------------------------------------------------------------------- + +std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + // RMSNorm reduction core: square -> mean over channels -> rsqrt. In FP16 mode compute this + // core in FP32 (cast input up, flip the working dtype so the core's op outputs + eps scalar are + // FP32, then cast 1/rms back down). The FP16 channel reduction loses too much precision on the + // ANE; only this core is FP32 - the scaling/weight/mask below stay FP16. No addConstOp lives in + // the flipped window, so weight serialization is unaffected. + std::string sqSrc = input; + if (m_use_fp16) { + sqSrc = genVarName(prefix + "_in32"); + addCastOp(block, input, sqSrc, "fp32", {-1, C, H, W}); + } + ScopedFp32 fp32Scope(m_weight_dtype, m_use_fp16); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", sqSrc, sqSrc, sq, {-1, C, H, W}); + // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so + // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can + // overflow FP16 (and the FP16 accumulation on ANE) for large activations. + std::string meanSq = genVarName(prefix + "_meansq"); + { + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + } + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, {-1, 1, H, W}); + } + std::string inv = invCore; + if (m_use_fp16) { + fp32Scope.restore(); + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", {-1, 1, H, W}); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string weightName = prefix + "_weight"; + addConstOp(block, weightName, desc.weight, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, weightName, scaled, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + emit2("mul", scaled, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + auto reduceSum = [&](const std::string& x, const std::string& out, const std::vector& axes, + const std::vector& dims) { + std::string axesName = out + "_axes"; + std::string keepName = out + "_keep"; + addIntArrayConstOp(block, axesName, axes); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_sum"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, out, dims); + }; + + // Variance core (mask -> square -> reduce -> rsqrt) in FP32 when in FP16 mode. The trunk-tip + // norm in particular reduces over many elements and loses too much precision in FP16 on the + // ANE; compute the core in FP32 and cast 1/rms back to FP16. Only the core is FP32 - gamma/beta, + // the activation and the final mask below stay FP16. No addConstOp lives in the flipped window. + std::string tinput = input; + std::string tmask = mask; + if (m_use_fp16) { + tinput = genVarName(prefix + "_in32"); + addCastOp(block, input, tinput, "fp32", {-1, C, H, W}); + tmask = genVarName(prefix + "_mask32"); + addCastOp(block, mask, tmask, "fp32", {-1, 1, H, W}); + } + ScopedFp32 fp32Scope(m_weight_dtype, m_use_fp16); + std::string masked = genVarName(prefix + "_premask"); + emit2("mul", tinput, tmask, masked, {-1, C, H, W}); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", masked, masked, sq, {-1, C, H, W}); + + std::string meanSq; + std::vector denomDims; + if (desc.spatial) { + // Mean of squares over valid positions and channels. A reduce_sum over C*H*W elements + // overflows FP16 (e.g. trunk tip with large activations on ANE -> inf -> rsqrt 0 -> + // collapse). Instead take reduce_mean over all of C,H,W (masked positions are zero) and + // rescale by totalPositions/validCount to restrict the mean to valid positions. + std::string meanAll = genVarName(prefix + "_meanall"); + { + std::string axesName = meanAll + "_axes", keepName = meanAll + "_keep"; + addIntArrayConstOp(block, axesName, {1, 2, 3}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanAll, {-1, 1, 1, 1}); + } + std::string count = genVarName(prefix + "_count"); + reduceSum(tmask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + std::string totalPosName = prefix + "_totalpos"; + addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); + std::string scaleF = genVarName(prefix + "_scalef"); + emit2("real_div", totalPosName, count, scaleF, {-1, 1, 1, 1}); // totalPos / validCount + meanSq = genVarName(prefix + "_meansq"); + emit2("mul", meanAll, scaleF, meanSq, {-1, 1, 1, 1}); + denomDims = {-1, 1, 1, 1}; + } else { + meanSq = genVarName(prefix + "_meansq"); + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + denomDims = {-1, 1, H, W}; + } + + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, denomDims); + } + std::string inv = invCore; + if (m_use_fp16) { + fp32Scope.restore(); + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", denomDims); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string gammaName = prefix + "_gamma"; + std::string betaName = prefix + "_beta"; + addConstOp(block, gammaName, desc.gamma, {1, static_cast(C), 1, 1}); + addConstOp(block, betaName, desc.beta, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, gammaName, scaled, {-1, C, H, W}); + std::string biased = genVarName(prefix + "_biased"); + emit2("add", scaled, betaName, biased, {-1, C, H, W}); + + std::string activated; + if (act.activation_type == ActivationType::Silu) { + activated = genVarName(prefix + "_act"); + addSiluOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::Mish) { + activated = genVarName(prefix + "_act"); + addMishOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::ReLU) { + activated = genVarName(prefix + "_act"); + auto* op = block->add_operations(); + op->set_type("relu"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(biased); + setShape(op, activated, {-1, C, H, W}); + } else { + activated = biased; + } + std::string out = genVarName(prefix + "_out"); + emit2("mul", activated, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.q_proj.in_channels; + const int H = m_board_y_size, W = m_board_x_size; + const int seq = H * W; + const int numHeads = desc.num_heads, numKVHeads = desc.num_kv_heads; + const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; + const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims, bool transX, bool transY) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, transX); + addBoolScalarConstOp(block, tyName, transY); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + // Q/K/V projection matmuls in FP32 (non-spatial, per KataGo's FP16 convention): they reduce over + // C channels and the ANE's FP16 accumulation loses too much precision for wide models. Weights + // stay fp16-stored (cast up at runtime); output cast back to FP16 for the FP16 head reshapes. + auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { + std::string wName = nm + "_w"; + addConstOp(block, wName, w.weights, w.getWeightShape()); + std::string out = genVarName(nm); + if (m_nonspatial_fp32) { + std::string x32 = castFixed(block, x2d, "fp32", {-1, C}); + std::string w32 = castFixed(block, wName, "fp32", w.getWeightShape()); + std::string o32 = genVarName(nm + "_f32"); + { ScopedFp32 fp32Scope(m_weight_dtype); matmul(x32, w32, o32, {-1, total}, false, false); } + out = castFixed(block, o32, "fp16", {-1, total}); + } else { + matmul(x2d, wName, out, {-1, total}, false, false); + } + return out; + }; + std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); + std::string k2d = proj(desc.k_proj, prefix + "_k", kTotal); + std::string v2d = proj(desc.v_proj, prefix + "_v", vTotal); + auto toHeads = [&](const std::string& in2d, const std::string& nm, int nh, int hd) { + std::string r = genVarName(nm + "_r"); + reshape(in2d, r, {-1, seq, nh, hd}, {-1, seq, nh, hd}); + std::string t = genVarName(nm + "_t"); + transpose(r, t, {0, 2, 1, 3}, {-1, nh, seq, hd}); + return t; + }; + std::string qh = toHeads(q2d, prefix + "_qh", numHeads, qHeadDim); + std::string kh = toHeads(k2d, prefix + "_kh", numKVHeads, qHeadDim); + std::string vh = toHeads(v2d, prefix + "_vh", numKVHeads, vHeadDim); + + if (desc.use_rope) { + const int numPairs = qHeadDim / 2; + const int numPairsPerDim = numPairs / 2; + const int dimHalf = qHeadDim / 2; + auto applyRope = [&](const std::string& x, int nh, const std::string& tag) { + std::vector cosFull(static_cast(nh) * seq * qHeadDim, 0.0f); + std::vector sinFull(static_cast(nh) * seq * qHeadDim, 0.0f); + for (int h = 0; h < nh; h++) { + int kvh = (h * numKVHeads) / nh; + for (int xy = 0; xy < seq; xy++) { + int y = xy / W; + int x = xy % W; + for (int p = 0; p < numPairs; p++) { + float angle = 0.0f; + if (desc.learnable_rope) { + float fx = desc.rope_freqs[(kvh * numPairs + p) * 2 + 0]; + float fy = desc.rope_freqs[(kvh * numPairs + p) * 2 + 1]; + angle = static_cast(x) * fx + static_cast(y) * fy; + } else { + if (p < numPairsPerDim) { + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * p) / dimHalf); + angle = static_cast(y) * freq; + } else { + int pAdj = p - numPairsPerDim; + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * pAdj) / dimHalf); + angle = static_cast(x) * freq; + } + } + float c = std::cos(angle), s = std::sin(angle); + size_t base = (static_cast(h) * seq + xy) * qHeadDim + 2 * p; + cosFull[base] = c; cosFull[base + 1] = c; + sinFull[base] = s; sinFull[base + 1] = s; + } + } + } + std::vector R(static_cast(qHeadDim) * qHeadDim, 0.0f); + for (int p = 0; p < numPairs; p++) { + R[(2 * p) * qHeadDim + (2 * p + 1)] = 1.0f; + R[(2 * p + 1) * qHeadDim + (2 * p)] = -1.0f; + } + std::string cosName = prefix + "_" + tag + "_cos"; + std::string sinName = prefix + "_" + tag + "_sin"; + std::string rName = prefix + "_" + tag + "_R"; + // cosFull/sinFull/R are locals computed here, so register them as OWNED consts: the + // WeightEntry holds a non-owning FloatView and serialization runs after this lambda + // returns, so a non-owning addConstOp would dangle. + addOwnedConstOp(block, cosName, std::move(cosFull), {1, nh, seq, qHeadDim}); + addOwnedConstOp(block, sinName, std::move(sinFull), {1, nh, seq, qHeadDim}); + // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. + addOwnedConstOp(block, rName, std::move(R), {1, 1, qHeadDim, qHeadDim}); + std::string rotated = genVarName(prefix + "_" + tag + "_rot"); + matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); + std::string xc = genVarName(prefix + "_" + tag + "_xc"); + binary("mul", x, cosName, xc, {-1, nh, seq, qHeadDim}); + std::string rs = genVarName(prefix + "_" + tag + "_rs"); + binary("mul", rotated, sinName, rs, {-1, nh, seq, qHeadDim}); + std::string out = genVarName(prefix + "_" + tag + "_rope"); + binary("add", xc, rs, out, {-1, nh, seq, qHeadDim}); + return out; + }; + qh = applyRope(qh, numHeads, "q"); + kh = applyRope(kh, numKVHeads, "k"); + } + + // GQA: when numKVHeads < numHeads, repeat each KV head groupSize times along the head + // axis (axis 1) so query head h consumes kv head (h / groupSize). RoPE has already been + // applied above to the unexpanded kh (kh = applyRope(kh, numKVHeads, "k")), mirroring the + // GPU path (metallayers.swift repeatKVHeads runs AFTER applyRope). We slice each KV head + // and concat its copies consecutively, so the resulting head index is kv*groupSize + g; + // query head h then maps to kv = h/groupSize == (h*numKVHeads)/numHeads (exact divisor, + // the same formula the qh RoPE table uses) == Eigen's kvh = h/kvGroupSize. slice_by_size + + // concat (not reshape+broadcast) avoids the dynamic -1 batch broadcast pitfall, same as the + // GPU code. The repeat is required so the scores (qh@kh^T) and attnOut (attn@vh) matmuls see + // matching [B,numHeads,...] batch dims instead of numHeads vs numKVHeads (no broadcast). + if (numKVHeads != numHeads) { + const int groupSize = numHeads / numKVHeads; + auto repeatKVHeads = [&](const std::string& x, const std::string& tag, int headDim) { + std::vector parts; + parts.reserve(static_cast(numKVHeads) * groupSize); + for (int kv = 0; kv < numKVHeads; kv++) { + for (int g = 0; g < groupSize; g++) { + std::string part = genVarName(prefix + "_" + tag + "_slc"); + std::string beginName = part + "_begin", sizeName = part + "_size"; + addIntArrayConstOp(block, beginName, {0, kv, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, headDim}); + auto* sop = block->add_operations(); + sop->set_type("slice_by_size"); + (*sop->mutable_inputs())["x"].add_arguments()->set_name(x); + (*sop->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*sop->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(sop, part, {-1, 1, seq, headDim}); + parts.push_back(part); + } + } + std::string out = genVarName(prefix + "_" + tag + "_exp"); + std::string axisName = out + "_axis", interleaveName = out + "_interleave"; + addIntScalarConstOp(block, axisName, 1); + addBoolScalarConstOp(block, interleaveName, false); + auto* cop = block->add_operations(); + cop->set_type("concat"); + auto& cin = *cop->mutable_inputs(); + for (const std::string& part : parts) + cin["values"].add_arguments()->set_name(part); + cin["axis"].add_arguments()->set_name(axisName); + cin["interleave"].add_arguments()->set_name(interleaveName); + setShape(cop, out, {-1, numHeads, seq, headDim}); + return out; + }; + kh = repeatKVHeads(kh, "khrep", qHeadDim); + vh = repeatKVHeads(vh, "vhrep", vHeadDim); + } + + std::string scores = genVarName(prefix + "_scores"); + matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); + std::string scaleName = prefix + "_scale"; + addFloatScalarConstOp(block, scaleName, 1.0f / std::sqrt(static_cast(qHeadDim))); + std::string scaled = genVarName(prefix + "_sc"); + binary("mul", scores, scaleName, scaled, {-1, numHeads, seq, seq}); + + // mask [B,1,H,W] -> [B,1,1,seq] directly (contiguous reshape; H,W already trailing so the + // row-major flatten gives seq index xy=y*W+x). No transpose -> avoids the reshape-after- + // transpose issue, and is also correct for non-full boards. + std::string maskSeq = genVarName(prefix + "_mseq"); + reshape(mask, maskSeq, {-1, 1, 1, seq}, {-1, 1, 1, seq}); + std::string oneName = prefix + "_one"; + addFloatScalarConstOp(block, oneName, 1.0f); + std::string mm1 = genVarName(prefix + "_mm1"); + binary("sub", maskSeq, oneName, mm1, {-1, 1, 1, seq}); + // Use an FP16-safe magnitude: 1e9 overflows FP16 to +inf, and for valid keys + // (maskSeq-1 == 0) the product 0 * inf becomes NaN, poisoning the whole softmax. + // 1e4 is well within FP16 range and exp(score - 1e4) still underflows to 0. + std::string bigName = prefix + "_big"; + addFloatScalarConstOp(block, bigName, 1.0e4f); + std::string keyBias = genVarName(prefix + "_kb"); + binary("mul", mm1, bigName, keyBias, {-1, 1, 1, seq}); + std::string scoresMasked = genVarName(prefix + "_scm"); + binary("add", scaled, keyBias, scoresMasked, {-1, numHeads, seq, seq}); + + std::string attn = genVarName(prefix + "_attn"); + { + std::string axisName = attn + "_axis"; + addIntScalarConstOp(block, axisName, 3); + auto* op = block->add_operations(); + op->set_type("softmax"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(scoresMasked); + (*op->mutable_inputs())["axis"].add_arguments()->set_name(axisName); + setShape(op, attn, {-1, numHeads, seq, seq}); + } + + std::string attnOut = genVarName(prefix + "_ao"); + matmul(attn, vh, attnOut, {-1, numHeads, seq, vHeadDim}, false, false); + + // Output projection, done per-head to avoid reshape-after-transpose: CoreML's reshape + // ignores an immediately-preceding transpose, so merging [head,dim]->channels after a + // transpose scrambles the data. Instead slice each head from attnOut (head is the + // contiguous axis 1), reshape (leading-merge only), matmul its weight slice, and sum. + // out[b,s,c] = sum_h sum_d attnOut[b,h,s,d] * outProj.weights[(h*vHeadDim+d)*outC + c] + const int outC = desc.out_proj.out_channels; + std::string proj2d; + for (int h = 0; h < numHeads; h++) { + std::string aoh = genVarName(prefix + "_aoh"); + { + std::string beginName = aoh + "_begin", sizeName = aoh + "_size"; + addIntArrayConstOp(block, beginName, {0, h, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, vHeadDim}); + auto* op = block->add_operations(); + op->set_type("slice_by_size"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(attnOut); + (*op->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*op->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(op, aoh, {-1, 1, seq, vHeadDim}); + } + std::string aoh2d = genVarName(prefix + "_aoh2d"); + reshape(aoh, aoh2d, {-1, vHeadDim}, {-1, vHeadDim}); // [B*seq, vHeadDim] + std::string wh = prefix + "_ow" + std::to_string(h); + std::vector whData(static_cast(vHeadDim) * outC); + for (int d = 0; d < vHeadDim; d++) + for (int c = 0; c < outC; c++) + whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; + // whData is a per-head local slice; register OWNED so its FloatView stays valid until + // serialization (a non-owning addConstOp would dangle after this loop iteration). + addOwnedConstOp(block, wh, std::move(whData), {vHeadDim, outC}); + std::string contrib = genVarName(prefix + "_contrib"); + matmul(aoh2d, wh, contrib, {-1, outC}, false, false); + if (h == 0) { + proj2d = contrib; + } else { + std::string acc = genVarName(prefix + "_acc"); + binary("add", proj2d, contrib, acc, {-1, outC}); + proj2d = acc; + } + } + std::string projNHWC = genVarName(prefix + "_pnhwc"); + reshape(proj2d, projNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string projNCHW = genVarName(prefix + "_pnchw"); + transpose(projNHWC, projNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", projNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int ffn = desc.ffn_channels; + const int H = m_board_y_size, W = m_board_x_size; + + if (!desc.use_swiglu) { + throw std::runtime_error(desc.name + ": non-SwiGLU transformer FFN not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, false); + addBoolScalarConstOp(block, tyName, false); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + + // FFN matmuls in FP32 (weights cast up at runtime, stored fp16) — KataGo's FP16 convention is + // spatial(convs)=FP16, non-spatial(matmuls)=FP32 (see openclbackend.cpp). The ANE accumulates + // FP16 matmuls in FP16, which loses too much precision over C/ffn; run them in FP32 instead. + std::string w1 = prefix + "_w1"; + addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); + std::string wg = prefix + "_wg"; + addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + + std::string mx2d = x2d, mw1 = w1, mwg = wg, mw2 = w2; + if (m_nonspatial_fp32) { + mx2d = castFixed(block, x2d, "fp32", {-1, C}); + mw1 = castFixed(block, w1, "fp32", desc.linear1.getWeightShape()); + mwg = castFixed(block, wg, "fp32", desc.linear_gate.getWeightShape()); + mw2 = castFixed(block, w2, "fp32", desc.linear2.getWeightShape()); + } + ScopedFp32 fp32Scope(m_weight_dtype, m_nonspatial_fp32); + std::string a = genVarName(prefix + "_a"); + matmul(mx2d, mw1, a, {-1, ffn}); + std::string g = genVarName(prefix + "_g"); + matmul(mx2d, mwg, g, {-1, ffn}); + + std::string sig = genVarName(prefix + "_sig"); + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(a); + setShape(op, sig, {-1, ffn}); + } + std::string siluA = genVarName(prefix + "_silu"); + binary("mul", a, sig, siluA, {-1, ffn}); + std::string h = genVarName(prefix + "_h"); + binary("mul", siluA, g, h, {-1, ffn}); + + std::string oCore = genVarName(prefix + "_o"); + matmul(h, mw2, oCore, {-1, C}); + std::string o = oCore; + if (m_nonspatial_fp32) { + fp32Scope.restore(); + o = castFixed(block, oCore, "fp16", {-1, C}); + } + + std::string oNHWC = genVarName(prefix + "_onhwc"); + reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string oNCHW = genVarName(prefix + "_onchw"); + transpose(oNHWC, oNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", oNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, const std::string& spatial_input, const std::string& global_input, @@ -1747,12 +2620,23 @@ std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, } else if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { const auto& block_desc = std::get(*entry.block); x = buildNestedBottleneckBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, block_desc, mask, prefix); } } // Trunk tip - std::string trunk_out = genVarName("trunk_tip"); - addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + std::string trunk_out; + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk_out = genVarName("trunk_tip"); + addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + } else { + trunk_out = addTrunkRMSNorm(block, x, trunk.trunk_tip_rms_norm, trunk.trunk_tip_activation, mask, "trunk_tip_rms"); + } return trunk_out; } @@ -1814,9 +2698,11 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling + // Global pooling (FP32 when m_nonspatial_fp32 -- see addGlobalPoolingFp32). Feeds a bias back + // into the whole trunk, so the FP16 spatial sum must not lose precision for wide trunks. std::string gpool_features = genVarName(prefix + "_gpool_features"); - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + addGlobalPoolingFp32(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features, + /*valueVariant=*/false); // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -1898,6 +2784,12 @@ std::string MILBuilder::buildNestedBottleneckBlock(CoreML::Specification::MILSpe } else if (entry.block_kind == GLOBAL_POOLING_BLOCK_KIND) { const auto& nested = std::get(*entry.block); x = buildGlobalPoolingResidualBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, nested, mask, nested_prefix); } } @@ -1942,9 +2834,9 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 + // Global pooling on G1 (FP32 when m_nonspatial_fp32; feeds the policy bias / policyKLDiv). std::string g1_pooled = genVarName("policy_g1_pool"); - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + addGlobalPoolingFp32(block, g1, mask, ph.g1_conv.out_channels, g1_pooled, /*valueVariant=*/false); // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2002,6 +2894,8 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, pass_activated, ph.gpool_to_pass_mul.out_channels); } else if (ph.pass_activation->activation_type == ActivationType::Mish) { addMishOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); + } else if (ph.pass_activation->activation_type == ActivationType::Silu) { + addSiluOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); } else { pass_activated = pass_biased; } @@ -2032,9 +2926,9 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) + // Global pooling (value head version; FP32 when m_nonspatial_fp32). std::string v1_pooled = genVarName("value_v1_pool"); - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + addGlobalPoolingFp32(block, v1, mask, vh.v1_conv.out_channels, v1_pooled, /*valueVariant=*/true); // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); @@ -2049,6 +2943,8 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, v2, vh.v2_mul.out_channels); } else if (vh.v2_activation.activation_type == ActivationType::Mish) { addMishOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); + } else if (vh.v2_activation.activation_type == ActivationType::Silu) { + addSiluOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); } else { v2 = v2_bias; } @@ -2085,6 +2981,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act1, encoder.mul1.out_channels); } else if (encoder.act1.activation_type == ActivationType::Mish) { addMishOps(block, bias1, act1, 2, encoder.mul1.out_channels); + } else if (encoder.act1.activation_type == ActivationType::Silu) { + addSiluOps(block, bias1, act1, 2, encoder.mul1.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); @@ -2107,6 +3005,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act2, encoder.mul2.out_channels); } else if (encoder.act2.activation_type == ActivationType::Mish) { addMishOps(block, bias2, act2, 2, encoder.mul2.out_channels); + } else if (encoder.act2.activation_type == ActivationType::Silu) { + addSiluOps(block, bias2, act2, 2, encoder.mul2.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..de3576f6a 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -29,8 +29,13 @@ class MILBuilder { /// @return Unique pointer to MIL Program protobuf std::unique_ptr build(); - /// Get weight entries for blob serialization - const std::vector& getWeights() const { return m_ops.getWeights(); } + /// Get weight entries for blob serialization (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_ops.getWeightsMutable(); } + + /// Effective FP16-IO decision after any narrow-transformer FP32 downgrade in + /// the constructor. The serializer must use this (not the raw request) so the + /// feature-descriptor IO dtype matches the MIL program's actual IO tensors. + bool getUseFp16Io() const { return m_use_fp16_io; } /// Get board dimensions int getBoardXSize() const { return m_board_x_size; } @@ -43,6 +48,16 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; + // FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer + // trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation). + // Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass + // testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32. + // For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls + + // pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; @@ -80,6 +95,24 @@ class MILBuilder { const std::vector& data, const std::vector& shape); + // addConstOp registers a NON-OWNING view into `data` (see WeightEntry), so the + // backing storage must outlive serialization. Binding a temporary here would + // dangle. Deleted so such calls fail to compile; use addOwnedConstOp for + // derived/temporary tensors that KataGoOps should own instead. + void addConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + + void addOwnedConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + std::vector&& data, + const std::vector& shape); + + void emitConstOp(CoreML::Specification::MILSpec::Block* block, + const std::string& name, + const std::vector& shape); + void addIntArrayConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& values); @@ -102,6 +135,23 @@ class MILBuilder { const std::string& dtype, const std::vector& shape); + // Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for + // weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the + // new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim. + std::string castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims); + + // Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool, + // cast the pooled features back to FP16). valueVariant selects the value-head pooling variant. + void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -120,6 +170,44 @@ class MILBuilder { int rank, int channels); + void addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels); + + // Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions. + void setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims); + + // Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out. + std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix); + + // Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out. + std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block, const std::string& input, const std::string& mask, diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..e86364943 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -14,12 +14,30 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { WeightEntry entry; entry.name = name; - entry.data = data; + entry.data = FloatView{data.data(), data.size()}; entry.shape = shape; entry.blob_offset = 0; // Will be set during serialization + entry.is_fp32 = is_fp32; + m_weights.push_back(std::move(entry)); + return name; +} + +std::string KataGoOps::registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape, + bool is_fp32) { + m_owned.push_back(std::move(data)); + const std::vector& stored = m_owned.back(); + WeightEntry entry; + entry.name = name; + entry.data = FloatView{stored.data(), stored.size()}; + entry.shape = shape; + entry.blob_offset = 0; + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..50d148311 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -5,17 +5,33 @@ #include "../types/KataGoTypes.hpp" #include +#include #include #include namespace katagocoreml { -/// Weight entry for blob file storage +/// Minimal non-owning view over a contiguous float buffer. KataGo-local on +/// purpose: keeps the MILBlob dependency out of this header (conversion to +/// MILBlob::Util::Span happens only at the serializer boundary). +struct FloatView { + const float* ptr = nullptr; + size_t len = 0; + const float* data() const { return ptr; } + size_t size() const { return len; } + bool empty() const { return len == 0; } + float operator[](size_t i) const { return ptr[i]; } +}; + +/// Weight entry for blob file storage. `data` is a NON-OWNING view into the live +/// KataGoModelDesc (or into KataGoOps::m_owned for derived tensors). struct WeightEntry { std::string name; - std::vector data; + FloatView data; // non-owning view (replaces raw ptr + count) std::vector shape; uint64_t blob_offset = 0; // Set during serialization + bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an + // FP32 sub-region of an otherwise-FP16 model). Else stored per global mode. }; /// Precomputed constants for identity mask optimization @@ -51,16 +67,41 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight that lives in the model (stored as a non-owning view). + /// is_fp32 marks it for FP32 storage. std::string registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); + + /// The stored WeightEntry is a non-owning view into `data`, so a temporary + /// would leave it dangling. Deleted to reject such calls at compile time; + /// use registerOwnedWeight for tensors KataGoOps should own. + std::string registerWeight(const std::string& name, + std::vector&& data, + const std::vector& shape) = delete; + // Also block the is_fp32-bearing call shape: without this overload, a call + // like registerWeight(name, std::move(vec), shape, true) has 4 arguments and + // would NOT match the 3-arg deleted overload above, instead binding the + // temporary to the const-ref live overload and leaving a dangling view. + std::string registerWeight(const std::string& name, + std::vector&& data, + const std::vector& shape, + bool is_fp32) = delete; + + /// Register a derived/temporary weight; KataGoOps takes ownership so the + /// view stays valid through serialization. is_fp32 marks it for FP32 storage + /// (mirrors registerWeight) so the stored dtype matches the declared const dtype. + std::string registerOwnedWeight(const std::string& name, + std::vector&& data, + const std::vector& shape, + bool is_fp32 = false); - /// Get all registered weights - const std::vector& getWeights() const { return m_weights; } + /// Get all registered weights (mutable; serialization sets blob_offset) + std::vector& getWeightsMutable() { return m_weights; } - /// Clear all registered weights - void clearWeights() { m_weights.clear(); } + /// Clear all registered weights (and their owned backing buffers) + void clearWeights() { m_weights.clear(); m_owned.clear(); } /// Generate unique operation name std::string genOpName(const std::string& prefix); @@ -71,6 +112,7 @@ class KataGoOps { bool m_optimize_identity_mask; MaskConstants m_mask_constants; std::vector m_weights; + std::deque> m_owned; int m_op_counter = 0; }; diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 68f1a0e56..c6dadc5c2 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -30,54 +29,41 @@ bool KataGoParser::isVersionSupported(int version) { } // ============================================================================ -// File Loading +// Stream Primitives // ============================================================================ -void KataGoParser::loadFile() { - // Check if gzip compressed - bool is_gzip = false; - if (m_model_path.size() >= 3) { - std::string ext = m_model_path.substr(m_model_path.size() - 3); - is_gzip = (ext == ".gz"); - } - - if (is_gzip) { - // Read gzipped file - gzFile gz = gzopen(m_model_path.c_str(), "rb"); - if (!gz) { - throw std::runtime_error("Cannot open gzip file: " + m_model_path); - } - - // Read in chunks - m_buffer.clear(); - std::vector chunk(1024 * 1024); // 1MB chunks - int bytes_read; - while ((bytes_read = gzread(gz, chunk.data(), static_cast(chunk.size()))) > 0) { - m_buffer.insert(m_buffer.end(), chunk.begin(), chunk.begin() + bytes_read); - } - - if (bytes_read < 0) { - int errnum; - const char* errmsg = gzerror(gz, &errnum); - gzclose(gz); - throw std::runtime_error("Error reading gzip file: " + std::string(errmsg)); - } - - gzclose(gz); - } else { - // Read regular file - std::ifstream file(m_model_path, std::ios::binary | std::ios::ate); - if (!file) { - throw std::runtime_error("Cannot open file: " + m_model_path); - } +bool KataGoParser::refill() { + if(!m_gz) return false; + int n = gzread(m_gz.get(), m_refill.data(), (unsigned)m_refill.size()); + if(n < 0) { + int errnum; + const char* errmsg = gzerror(m_gz.get(), &errnum); + throw std::runtime_error("Error reading gzip stream: " + std::string(errmsg)); + } + m_refillPos = 0; + m_refillLen = (size_t)n; + return n > 0; +} - std::streamsize size = file.tellg(); - file.seekg(0, std::ios::beg); +int KataGoParser::peekByte() { + if(m_refillPos >= m_refillLen) { + if(!refill()) return -1; + } + return (int)m_refill[m_refillPos]; +} - m_buffer.resize(static_cast(size)); - if (!file.read(reinterpret_cast(m_buffer.data()), size)) { - throw std::runtime_error("Error reading file: " + m_model_path); +void KataGoParser::readExact(uint8_t* dst, size_t n, const std::string& name) { + size_t got = 0; + while(got < n) { + if(m_refillPos >= m_refillLen) { + if(!refill()) + throw std::runtime_error(name + ": unexpected EOF in binary block"); } + size_t avail = m_refillLen - m_refillPos; + size_t take = std::min(avail, n - got); + std::memcpy(dst + got, m_refill.data() + m_refillPos, take); + m_refillPos += take; + got += take; } } @@ -86,15 +72,16 @@ void KataGoParser::loadFile() { // ============================================================================ KataGoModelDesc KataGoParser::parse() { - loadFile(); - m_pos = 0; - - // Detect if binary format (check for @BIN@ marker) - const std::string bin_marker = "@BIN@"; - auto it = std::search(m_buffer.begin(), m_buffer.end(), - bin_marker.begin(), bin_marker.end()); - m_binary_floats = (it != m_buffer.end()); - + // Allocate the refill buffer first; if this throws, no handle has been opened. + m_refill.resize(1024 * 1024); + m_gz.reset(gzopen(m_model_path.c_str(), "rb")); + if(!m_gz) + throw std::runtime_error("Cannot open file: " + m_model_path); + m_refillPos = 0; + m_refillLen = 0; + m_formatDetected = false; // decided at first readFloats + m_binary_floats = true; + // ~GzHandle closes the file on normal return OR exception — no try/catch needed. return parseModel(); } @@ -103,24 +90,20 @@ KataGoModelDesc KataGoParser::parse() { // ============================================================================ void KataGoParser::skipWhitespace() { - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c != ' ' && c != '\t' && c != '\n' && c != '\r') { - break; - } - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c != ' ' && c != '\t' && c != '\n' && c != '\r') break; + m_refillPos++; } } void KataGoParser::readUntilWhitespace(std::string& out) { out.clear(); - while (m_pos < m_buffer.size()) { - char c = static_cast(m_buffer[m_pos]); - if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { - break; - } - out += c; - m_pos++; + int c; + while((c = peekByte()) >= 0) { + if(c == ' ' || c == '\t' || c == '\n' || c == '\r') break; + out += (char)c; + m_refillPos++; } } @@ -147,37 +130,28 @@ bool KataGoParser::readBool() { std::vector KataGoParser::readFloats(size_t count, const std::string& name) { std::vector floats(count); + skipWhitespace(); + + // KataGo model files are uniformly text OR uniformly binary, so detecting the + // format once at the first weight block (binary blocks start with '@BIN@') + // is valid for all subsequent blocks. + if(!m_formatDetected) { + m_binary_floats = (peekByte() == '@'); + m_formatDetected = true; + } - if (!m_binary_floats) { + if(!m_binary_floats) { // Text format - for (size_t i = 0; i < count; i++) { + for(size_t i = 0; i < count; i++) floats[i] = readFloat(); - } } else { - // Binary format - find @BIN@ marker - while (m_pos < m_buffer.size()) { - if (m_buffer[m_pos] == '@') { - break; - } - m_pos++; - } - - // Check for @BIN@ header - if (m_pos + 5 > m_buffer.size() || - std::memcmp(&m_buffer[m_pos], "@BIN@", 5) != 0) { + // Binary: consume the "@BIN@" marker, then read count*4 raw bytes. + char marker[5]; + readExact(reinterpret_cast(marker), 5, name); + if(std::memcmp(marker, "@BIN@", 5) != 0) throw std::runtime_error(name + ": expected @BIN@ marker for binary float block"); - } - m_pos += 5; - // Read binary floats (little-endian) - size_t num_bytes = count * 4; - if (m_pos + num_bytes > m_buffer.size()) { - throw std::runtime_error(name + ": not enough bytes for " + std::to_string(count) + " floats"); - } - - // Copy as little-endian float32 - std::memcpy(floats.data(), &m_buffer[m_pos], num_bytes); - m_pos += num_bytes; + readExact(reinterpret_cast(floats.data()), count * 4, name); } // Reject NaN/Inf weights: corrupted or otherwise invalid models would @@ -315,6 +289,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) { layer.activation_type = ActivationType::ReLU; } else if (activation_str == "ACTIVATION_MISH") { layer.activation_type = ActivationType::Mish; + } else if (activation_str == "ACTIVATION_SILU") { + layer.activation_type = ActivationType::Silu; } else { throw std::runtime_error("Unknown activation type: " + activation_str); } @@ -332,6 +308,14 @@ MatMulLayerDesc KataGoParser::parseMatMulLayer() { layer.in_channels = readInt(); layer.out_channels = readInt(); + // Match master desc.cpp's MatMulLayerDesc validation (inChannels/outChannels + // <= 0): reject non-positive dims so a malformed model fails loudly here + // rather than building a degenerate matmul or, on a negative value, + // computing a huge size_t allocation below. + if (layer.in_channels <= 0 || layer.out_channels <= 0) { + throw std::runtime_error(layer.name + ": invalid matmul in/out channels"); + } + // Weights in [ic, oc] order size_t num_weights = static_cast(layer.in_channels) * layer.out_channels; layer.weights = readFloats(num_weights, layer.name); @@ -343,6 +327,12 @@ MatBiasLayerDesc KataGoParser::parseMatBiasLayer() { MatBiasLayerDesc layer; layer.name = readString(); layer.num_channels = readInt(); + // Match master desc.cpp (matbias numChannels <= 0): fail loud before + // readFloats turns a negative into a huge size_t allocation. + if (layer.num_channels <= 0) { + throw std::runtime_error(layer.name + ": matbias numChannels must be >= 1, got " + + std::to_string(layer.num_channels)); + } layer.weights = readFloats(layer.num_channels, layer.name); return layer; @@ -387,6 +377,13 @@ NestedBottleneckResidualBlockDesc KataGoParser::parseNestedBottleneckBlock(int m NestedBottleneckResidualBlockDesc block; block.name = readString(); block.num_blocks = readInt(); + // Match master desc.cpp (nested numBlocks < 1): validate before parseBlockStack + // reserves on this count, so a crafted/corrupt model can't force a huge + // allocation (or a confusing length_error on a negative) from a tiny file. + if (block.num_blocks < 1) { + throw std::runtime_error(block.name + ": nested numBlocks must be >= 1, got " + + std::to_string(block.num_blocks)); + } block.pre_bn = parseBatchNormLayer(); block.pre_activation = parseActivationLayer(model_version); @@ -420,6 +417,167 @@ static void checkBlockChannels(const std::string& block_name, const std::string& } } +// Validate one transformer-attention projection matmul dimension against the head +// geometry / trunk width. The CoreML graph builder reshapes each projection's flat +// output into a [seq, heads, headDim] grid and reshapes the out-projection result +// back into the trunk, so a declared dimension that disagrees either reads past the +// weight buffer or builds a graph that compiles but computes nonsense. +// Match master desc.cpp's transformer attention consistency checks. +static void checkAttentionProjDim(const std::string& block_name, const std::string& dim_name, + int actual, const std::string& expected_name, int expected) { + if (actual != expected) { + throw std::runtime_error(block_name + ": " + dim_name + " (" + std::to_string(actual) + + ") != " + expected_name + " (" + std::to_string(expected) + ")"); + } +} + +TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() { + TransformerRMSNormDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": transformer rmsnorm numChannels must be >= 1"); + } + // Match master desc.cpp (rmsnorm epsilon <= 0 || > 1.0f): reject a non-positive + // or too-large epsilon so a malformed model fails loudly here instead of + // producing a structurally-valid model that computes garbage via rsqrt(x+eps). + if (layer.epsilon <= 0.0f || layer.epsilon > 1.0f) { + throw std::runtime_error(layer.name + ": transformer rmsnorm epsilon must be in (0, 1]"); + } + layer.weight = readFloats(layer.num_channels, layer.name + "/weight"); + return layer; +} + +RMSNormLayerDesc KataGoParser::parseRMSNormLayer() { + RMSNormLayerDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + layer.spatial = (readInt() != 0); + layer.cgroup_size = readInt(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": rmsnorm numChannels must be >= 1"); + } + if (layer.cgroup_size != 0) { + throw std::runtime_error(layer.name + ": grouped spatial RMSNorm is not supported"); + } + // Match master desc.cpp (rmsnorm epsilon <= 0 || > 1.0f). + if (layer.epsilon <= 0.0f || layer.epsilon > 1.0f) { + throw std::runtime_error(layer.name + ": rmsnorm epsilon must be in (0, 1]"); + } + layer.gamma = readFloats(layer.num_channels, layer.name + "/gamma"); + layer.beta = readFloats(layer.num_channels, layer.name + "/beta"); + return layer; +} + +TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version, int trunk_num_channels) { + TransformerAttentionBlockDesc block; + block.name = readString(); + block.num_heads = readInt(); + block.num_kv_heads = readInt(); + block.q_head_dim = readInt(); + block.v_head_dim = readInt(); + block.use_rope = (readInt() != 0); + block.learnable_rope = (readInt() != 0); + + if (block.num_heads < 1 || block.num_kv_heads < 1 || (block.num_heads % block.num_kv_heads != 0)) { + throw std::runtime_error(block.name + ": invalid numHeads/numKVHeads"); + } + // Match master desc.cpp (qHeadDim < 1 || vHeadDim < 1): a degenerate head dim + // makes scale = 1/sqrt(0) = inf and reshape dims degenerate. The projection + // cross-checks below only catch *inconsistent* dims, not a consistently-zero one. + if (block.q_head_dim < 1 || block.v_head_dim < 1) { + throw std::runtime_error(block.name + ": qHeadDim and vHeadDim must be >= 1"); + } + if (block.use_rope && (block.q_head_dim % 2 != 0)) { + throw std::runtime_error(block.name + ": qHeadDim must be even when RoPE is used"); + } + + block.pre_ln = parseTransformerRMSNorm(); + block.q_proj = parseMatMulLayer(); + block.k_proj = parseMatMulLayer(); + block.v_proj = parseMatMulLayer(); + block.out_proj = parseMatMulLayer(); + + // Validate the four projection matmul dimensions against the head geometry and + // trunk width. Six mirror master desc.cpp's transformer attention checks + // (desc.cpp:1129-1136, 1430-1443); the k/v inChannels checks additionally close + // a gap master leaves implicit to the backend - all three QKV projections + // consume the same normed-trunk input, so their inChannels must equal the trunk + // width. K pairs with Q in the QK^T dot product, so kProj uses qHeadDim; only V + // carries vHeadDim. + checkAttentionProjDim(block.name, "qProj.inChannels", block.q_proj.in_channels, "trunkNumChannels", trunk_num_channels); + checkAttentionProjDim(block.name, "qProj.outChannels", block.q_proj.out_channels, "numHeads*qHeadDim", block.num_heads * block.q_head_dim); + checkAttentionProjDim(block.name, "kProj.inChannels", block.k_proj.in_channels, "trunkNumChannels", trunk_num_channels); + checkAttentionProjDim(block.name, "kProj.outChannels", block.k_proj.out_channels, "numKVHeads*qHeadDim", block.num_kv_heads * block.q_head_dim); + checkAttentionProjDim(block.name, "vProj.inChannels", block.v_proj.in_channels, "trunkNumChannels", trunk_num_channels); + checkAttentionProjDim(block.name, "vProj.outChannels", block.v_proj.out_channels, "numKVHeads*vHeadDim", block.num_kv_heads * block.v_head_dim); + checkAttentionProjDim(block.name, "outProj.inChannels", block.out_proj.in_channels, "numHeads*vHeadDim", block.num_heads * block.v_head_dim); + checkAttentionProjDim(block.name, "outProj.outChannels", block.out_proj.out_channels, "trunkNumChannels", trunk_num_channels); + + if (block.use_rope) { + if (block.learnable_rope) { + readString(); // ropeFreqs name + block.rope_num_kv_heads = readInt(); + block.rope_num_pairs = readInt(); + int rope_dim2 = readInt(); + if (block.rope_num_kv_heads != block.num_kv_heads || + block.rope_num_pairs != block.q_head_dim / 2 || rope_dim2 != 2) { + throw std::runtime_error(block.name + ": invalid learnable rope header"); + } + block.rope_freqs = readFloats( + static_cast(block.rope_num_kv_heads) * block.rope_num_pairs * 2, + block.name + "/rope_freqs"); + } else { + readString(); // ropeTheta name + block.rope_theta = readFloat(); + // Match master desc.cpp (ropeTheta <= 0). Without this, a non-positive + // theta makes pow(theta, frac) NaN/Inf, which flows into the derived + // cos/sin RoPE tables. Those are builder-derived (not parsed) weights, + // so they bypass the readFloats NaN/Inf gate and would otherwise yield + // a structurally-valid model that silently computes garbage. + if (block.rope_theta <= 0.0f) { + throw std::runtime_error(block.name + ": ropeTheta must be > 0"); + } + } + } + return block; +} + +TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version, int trunk_num_channels) { + TransformerFFNBlockDesc block; + block.name = readString(); + block.num_channels = readInt(); + block.ffn_channels = readInt(); + block.use_swiglu = (readInt() != 0); + if (block.num_channels < 1 || block.ffn_channels < 1) { + throw std::runtime_error(block.name + ": transformer ffn channels must be positive"); + } + block.pre_ln = parseTransformerRMSNorm(); + block.linear1 = parseMatMulLayer(); + if (block.use_swiglu) { + block.linear_gate = parseMatMulLayer(); + } + block.linear2 = parseMatMulLayer(); + + // Validate the FFN matmul dimensions, mirroring the attention block's projection checks above. + // The block adds its output back into the trunk residually, so numChannels must equal the trunk + // width; the linear layers chain numChannels -> ffnChannels -> numChannels (with the SwiGLU gate + // also projecting numChannels -> ffnChannels). A declared dimension that disagrees builds a graph + // that either fails to compile or computes nonsense, so reject it at parse time. + checkAttentionProjDim(block.name, "ffn.numChannels", block.num_channels, "trunkNumChannels", trunk_num_channels); + checkAttentionProjDim(block.name, "linear1.inChannels", block.linear1.in_channels, "numChannels", block.num_channels); + checkAttentionProjDim(block.name, "linear1.outChannels", block.linear1.out_channels, "ffnChannels", block.ffn_channels); + if (block.use_swiglu) { + checkAttentionProjDim(block.name, "linearGate.inChannels", block.linear_gate.in_channels, "numChannels", block.num_channels); + checkAttentionProjDim(block.name, "linearGate.outChannels", block.linear_gate.out_channels, "ffnChannels", block.ffn_channels); + } + checkAttentionProjDim(block.name, "linear2.inChannels", block.linear2.in_channels, "ffnChannels", block.ffn_channels); + checkAttentionProjDim(block.name, "linear2.outChannels", block.linear2.out_channels, "numChannels", block.num_channels); + return block; +} + std::vector KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) { std::vector blocks; blocks.reserve(num_blocks); @@ -449,6 +607,14 @@ std::vector KataGoParser::parseBlockStack(int model_version, int num desc.pre_bn.num_channels, desc.post_conv.out_channels, trunk_num_channels); entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_attention_block") { + entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND; + auto desc = parseTransformerAttentionBlock(model_version, trunk_num_channels); + entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_ffn_block") { + entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND; + auto desc = parseTransformerFFNBlock(model_version, trunk_num_channels); + entry.block = std::make_shared(std::move(desc)); } else { throw std::runtime_error("Unknown block kind: " + block_kind_name); } @@ -506,15 +672,15 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) } // Version >= 15 writes the trunk norm kind followed by 5 unused int parameters. - // This CoreML parser only supports the standard trunk norm kind (0 = BatchNorm/BiasMask); - // RMSNorm (used by transformer/rmsnorm models, kind != 0) is not implemented here, so reject it - // defensively rather than silently parsing it as standard norm and producing wrong outputs. + // Unlike upstream's CoreML parser (which rejects any non-standard norm), this fork + // implements RMSNorm, so we capture the kind here instead of throwing. The 5 trailing + // ints are reserved and still expected to be zero. if (model_version >= 15) { - int trunk_norm_kind = readInt(); - if (trunk_norm_kind != 0) { - throw std::runtime_error(trunk.name + ": unsupported trunk norm kind " + - std::to_string(trunk_norm_kind) + - " (this CoreML parser only supports standard trunk norm, not RMSNorm)"); + trunk.trunk_norm_kind = readInt(); + if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD && + trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) { + throw std::runtime_error(trunk.name + ": unknown/unsupported trunk norm kind " + + std::to_string(trunk.trunk_norm_kind)); } for (int i = 0; i < 5; i++) { int unused = readInt(); @@ -561,14 +727,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) // Parse residual blocks trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels); - trunk.trunk_tip_bn = parseBatchNormLayer(); - trunk.trunk_tip_activation = parseActivationLayer(model_version); - if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { - throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + - std::to_string(trunk.trunk_tip_bn.num_channels) + - ") != trunkNumChannels (" + - std::to_string(trunk.trunk_num_channels) + ")"); + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk.trunk_tip_bn = parseBatchNormLayer(); + if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + + std::to_string(trunk.trunk_tip_bn.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } + } else { + trunk.trunk_tip_rms_norm = parseRMSNormLayer(); + if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" + + std::to_string(trunk.trunk_tip_rms_norm.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } } + trunk.trunk_tip_activation = parseActivationLayer(model_version); return trunk; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index a7d9f161c..09fb9cf84 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -5,8 +5,11 @@ #include "../types/KataGoTypes.hpp" #include +#include #include +#include #include +#include namespace katagocoreml { @@ -31,9 +34,23 @@ class KataGoParser { private: std::string m_model_path; - std::vector m_buffer; - size_t m_pos = 0; + // Custom-deleter unique_ptr owns the gzFile so it closes on every exit path + // (normal return, exception, or bad_alloc) without manual try/catch. + struct GzCloser { + void operator()(gzFile f) const noexcept { if(f) gzclose(f); } + }; + using GzHandle = std::unique_ptr, GzCloser>; + GzHandle m_gz; + std::vector m_refill; // bounded refill buffer (~1 MB) + size_t m_refillPos = 0; // read cursor within m_refill + size_t m_refillLen = 0; // valid bytes in m_refill bool m_binary_floats = true; + bool m_formatDetected = false; + + // Stream primitives + bool refill(); // returns false at EOF + int peekByte(); // -1 at EOF + void readExact(uint8_t* dst, size_t n, const std::string& name); // Low-level reading functions void readUntilWhitespace(std::string& out); @@ -50,11 +67,15 @@ class KataGoParser { ActivationLayerDesc parseActivationLayer(int model_version); MatMulLayerDesc parseMatMulLayer(); MatBiasLayerDesc parseMatBiasLayer(); + TransformerRMSNormDesc parseTransformerRMSNorm(); + RMSNormLayerDesc parseRMSNormLayer(); // Block parsing functions ResidualBlockDesc parseResidualBlock(int model_version); GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version); NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels); + TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version, int trunk_num_channels); + TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version, int trunk_num_channels); std::vector parseBlockStack(int model_version, int num_blocks, int trunk_num_channels); // Component parsing functions @@ -65,9 +86,6 @@ class KataGoParser { // Main model parsing KataGoModelDesc parseModel(); - - // Helper to load file (handles gzip) - void loadFile(); }; } // namespace katagocoreml diff --git a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp index f271f5526..50df8003f 100644 --- a/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/CoreMLSerializer.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace katagocoreml { @@ -230,8 +232,13 @@ void CoreMLSerializer::createPackage(const std::string& output_path, if (!out) { throw std::runtime_error("Failed to create temp model file"); } - if (!model->SerializeToOstream(&out)) { - throw std::runtime_error("Failed to serialize model spec"); + { + google::protobuf::io::OstreamOutputStream zos(&out); + google::protobuf::io::CodedOutputStream cos(&zos); + cos.SetSerializationDeterministic(true); + if (!model->SerializeToCodedStream(&cos)) { + throw std::runtime_error("Failed to serialize model spec"); + } } } diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..69d590609 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,20 +15,25 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { - if (use_fp16) { + const size_t count = entry.data.size(); + // Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not + // declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so + // stored bytes stay consistent with each const's declared dtype. + const bool store_fp16 = use_fp16 && !entry.is_fp32; + if (store_fp16) { // Convert FP32 weights to FP16 - std::vector fp16_data(entry.data.size()); - for (size_t i = 0; i < entry.data.size(); ++i) { + std::vector fp16_data(count); + for (size_t i = 0; i < count; ++i) { fp16_data[i] = MILBlob::Fp16::FromFloat(entry.data[i]); } MILBlob::Util::Span span(fp16_data.data(), fp16_data.size()); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(MILBlob::Fp16); + total_bytes += count * sizeof(MILBlob::Fp16); } else { - // Write FP32 weights - MILBlob::Util::Span span(entry.data.data(), entry.data.size()); + // Write FP32 weights — convert the KataGo-local view to a MILBlob span here. + MILBlob::Util::Span span(entry.data.data(), count); entry.blob_offset = writer.WriteData(span); - total_bytes += entry.data.size() * sizeof(float); + total_bytes += count * sizeof(float); } } diff --git a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp index 147541a39..1074ad419 100644 --- a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp +++ b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp @@ -20,10 +20,15 @@ namespace katagocoreml { enum class ActivationType : int { Identity = 0, ReLU = 1, - Mish = 2 + Mish = 2, + Silu = 3 // MISH_SCALE8 = 12 is internal optimization, treated as Mish }; +/// Trunk normalization kind (matching KataGo's desc.h) +constexpr int TRUNK_NORM_KIND_STANDARD = 0; +constexpr int TRUNK_NORM_KIND_RMSNORM = 1; + // ============================================================================ // Block Kind Constants // ============================================================================ @@ -32,6 +37,8 @@ enum class ActivationType : int { constexpr int ORDINARY_BLOCK_KIND = 0; constexpr int GLOBAL_POOLING_BLOCK_KIND = 2; constexpr int NESTED_BOTTLENECK_BLOCK_KIND = 3; +constexpr int TRANSFORMER_ATTENTION_BLOCK_KIND = 4; +constexpr int TRANSFORMER_FFN_BLOCK_KIND = 5; // ============================================================================ // Layer Descriptors @@ -99,6 +106,25 @@ struct MatBiasLayerDesc { std::vector weights; // Shape: [num_channels] }; +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +struct TransformerRMSNormDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + std::vector weight; // Shape: [num_channels] +}; + +/// Full-featured RMSNorm (gamma/beta, spatial mode) used at the trunk tip. +struct RMSNormLayerDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + bool spatial = false; + int cgroup_size = 0; + std::vector gamma; // Shape: [num_channels] + std::vector beta; // Shape: [num_channels] +}; + // ============================================================================ // Block Descriptors // ============================================================================ @@ -107,12 +133,16 @@ struct MatBiasLayerDesc { struct ResidualBlockDesc; struct GlobalPoolingResidualBlockDesc; struct NestedBottleneckResidualBlockDesc; +struct TransformerAttentionBlockDesc; +struct TransformerFFNBlockDesc; /// Block descriptor variant using BlockDesc = std::variant< ResidualBlockDesc, GlobalPoolingResidualBlockDesc, - NestedBottleneckResidualBlockDesc + NestedBottleneckResidualBlockDesc, + TransformerAttentionBlockDesc, + TransformerFFNBlockDesc >; /// Block with its kind @@ -166,6 +196,38 @@ struct NestedBottleneckResidualBlockDesc { ConvLayerDesc post_conv; }; +/// Transformer self-attention block descriptor (pre-norm, multi-head, optional 2D RoPE, GQA). +struct TransformerAttentionBlockDesc { + std::string name; + int num_heads = 0; + int num_kv_heads = 0; + int q_head_dim = 0; + int v_head_dim = 0; + bool use_rope = false; + bool learnable_rope = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc q_proj; + MatMulLayerDesc k_proj; + MatMulLayerDesc v_proj; + MatMulLayerDesc out_proj; + int rope_num_kv_heads = 0; + int rope_num_pairs = 0; + std::vector rope_freqs; // learnable: (num_kv_heads, num_pairs, 2) flattened + float rope_theta = 0.0f; +}; + +/// Transformer feed-forward (SwiGLU) block descriptor. +struct TransformerFFNBlockDesc { + std::string name; + int num_channels = 0; + int ffn_channels = 0; + bool use_swiglu = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc linear1; + MatMulLayerDesc linear_gate; // only used when use_swiglu + MatMulLayerDesc linear2; +}; + // ============================================================================ // SGF Metadata Encoder (v15+) // ============================================================================ @@ -203,7 +265,9 @@ struct TrunkDesc { MatMulLayerDesc initial_matmul; std::optional sgf_metadata_encoder; std::vector blocks; + int trunk_norm_kind = TRUNK_NORM_KIND_STANDARD; BatchNormLayerDesc trunk_tip_bn; + RMSNormLayerDesc trunk_tip_rms_norm; ActivationLayerDesc trunk_tip_activation; }; diff --git a/cpp/main.cpp b/cpp/main.cpp index 8e4cba8f3..60e6bd78b 100644 --- a/cpp/main.cpp +++ b/cpp/main.cpp @@ -36,7 +36,7 @@ match : Run self-play match games based on a config, more efficient than gtp due version : Print version and exit. analysis : Runs an engine designed to analyze entire games in parallel. -tuner : (OpenCL only) Run tuning to find and optimize parameters that work on your GPU. +tuner : (OpenCL and MLX) Run tuning to find and optimize parameters that work on your GPU. ---Selfplay training subcommands--------- @@ -246,6 +246,8 @@ string Version::getKataGoVersionFullInfo() { out << "Using OpenCL backend" << endl; #elif defined(USE_EIGEN_BACKEND) out << "Using Eigen(CPU) backend" << endl; +#elif defined(USE_MLX_BACKEND) + out << "Using MLX backend" << endl; #else out << "Using dummy backend" << endl; #endif @@ -282,6 +284,8 @@ string Version::getGitRevisionWithBackend() { s += "-opencl"; #elif defined(USE_EIGEN_BACKEND) s += "-eigen"; +#elif defined(USE_MLX_BACKEND) + s += "-mlx"; #else s += "-dummy"; #endif diff --git a/cpp/neuralnet/desc.cpp b/cpp/neuralnet/desc.cpp index 8141b4366..72e01238d 100644 --- a/cpp/neuralnet/desc.cpp +++ b/cpp/neuralnet/desc.cpp @@ -197,6 +197,10 @@ void ConvLayerDesc::scaleOutputChannels(const std::vector& scaling) { } } +void ConvLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- BatchNormLayerDesc::BatchNormLayerDesc() : numChannels(0), epsilon(0.001f), hasScale(false), hasBias(false) {} @@ -389,6 +393,15 @@ ActivationLayerDesc::ActivationLayerDesc(istream& in, int modelVersion) { } } +void BatchNormLayerDesc::releaseWeights() { + std::vector().swap(mean); + std::vector().swap(variance); + std::vector().swap(scale); + std::vector().swap(bias); + std::vector().swap(mergedScale); + std::vector().swap(mergedBias); +} + ActivationLayerDesc::ActivationLayerDesc(ActivationLayerDesc&& other) { *this = std::move(other); } @@ -517,6 +530,10 @@ MatBiasLayerDesc::MatBiasLayerDesc(istream& in, bool binaryFloats) { throw StringError(name + ": matbiaslayer failed to parse expected number of matbias weights"); } +void MatMulLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + MatBiasLayerDesc::MatBiasLayerDesc(MatBiasLayerDesc&& other) { *this = std::move(other); } @@ -538,6 +555,10 @@ void MatBiasLayerDesc::applyScale8ToReduceActivations() { } } +void MatBiasLayerDesc::releaseWeights() { + std::vector().swap(weights); +} + //----------------------------------------------------------------------------- ResidualBlockDesc::ResidualBlockDesc() {} @@ -617,6 +638,13 @@ void ResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void ResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- GlobalPoolingResidualBlockDesc::GlobalPoolingResidualBlockDesc() {} @@ -738,6 +766,16 @@ void GlobalPoolingResidualBlockDesc::applyScale8ToReduceActivations() { midActivation.applyScale8ToReduceActivations(); } +void GlobalPoolingResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + regularConv.releaseWeights(); + gpoolConv.releaseWeights(); + gpoolBN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + midBN.releaseWeights(); + finalConv.releaseWeights(); +} + //----------------------------------------------------------------------------- NestedBottleneckResidualBlockDesc::NestedBottleneckResidualBlockDesc() {} @@ -992,6 +1030,38 @@ void NestedBottleneckResidualBlockDesc::applyScale8ToReduceActivations() { postActivation.applyScale8ToReduceActivations(); } +void NestedBottleneckResidualBlockDesc::releaseWeights() { + preBN.releaseWeights(); + preConv.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_ATTENTION_BLOCK_KIND) { + TransformerAttentionDesc* desc = (TransformerAttentionDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_FFN_BLOCK_KIND) { + TransformerFFNDesc* desc = (TransformerFFNDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + postBN.releaseWeights(); + postConv.releaseWeights(); +} + //----------------------------------------------------------------------------- RMSNormLayerDesc::RMSNormLayerDesc() : numChannels(0), epsilon(0), spatial(false), cgroupSize(0) {} @@ -1043,6 +1113,11 @@ int64_t RMSNormLayerDesc::getNumParameters() const { return (int64_t)gamma.size() + (int64_t)beta.size(); } +void RMSNormLayerDesc::releaseWeights() { + std::vector().swap(gamma); + std::vector().swap(beta); +} + //----------------------------------------------------------------------------- TransformerRMSNormDesc::TransformerRMSNormDesc() : numChannels(0), epsilon(0) {} @@ -1083,6 +1158,10 @@ int64_t TransformerRMSNormDesc::getNumParameters() const { return (int64_t)weight.size(); } +void TransformerRMSNormDesc::releaseWeights() { + std::vector().swap(weight); +} + //----------------------------------------------------------------------------- TransformerAttentionDesc::TransformerAttentionDesc() @@ -1209,6 +1288,15 @@ int64_t TransformerAttentionDesc::getNumParameters() const { (int64_t)ropeFreqs.size(); // learnable RoPE frequencies, empty for fixed/no RoPE } +void TransformerAttentionDesc::releaseWeights() { + preLN.releaseWeights(); + qProj.releaseWeights(); + kProj.releaseWeights(); + vProj.releaseWeights(); + outProj.releaseWeights(); + std::vector().swap(ropeFreqs); +} + void TransformerAttentionDesc::computeRopeCosSin(int nnXLen, int nnYLen, int paddedNNXYLen, std::vector& cosTable, std::vector& sinTable) const { if(!useRope) throw StringError("TransformerAttentionDesc::computeRopeCosSin called but useRope is false"); @@ -1344,6 +1432,13 @@ int64_t TransformerFFNDesc::getNumParameters() const { linear2.getNumParameters(); } +void TransformerFFNDesc::releaseWeights() { + preLN.releaseWeights(); + linear1.releaseWeights(); + linearGate.releaseWeights(); + linear2.releaseWeights(); +} + //----------------------------------------------------------------------------- static void parseResidualBlockStack( @@ -1550,6 +1645,14 @@ int64_t SGFMetadataEncoderDesc::getNumParameters() const { mul3.getNumParameters(); } +void SGFMetadataEncoderDesc::releaseWeights() { + mul1.releaseWeights(); + bias1.releaseWeights(); + mul2.releaseWeights(); + bias2.releaseWeights(); + mul3.releaseWeights(); +} + //----------------------------------------------------------------------------- TrunkDesc::TrunkDesc() @@ -1906,6 +2009,40 @@ void TrunkDesc::applyScale8ToReduceActivations() { } } +void TrunkDesc::releaseWeights() { + initialConv.releaseWeights(); + initialMatMul.releaseWeights(); + if(metaEncoderVersion > 0) + sgfMetadataEncoder.releaseWeights(); + for(int i = 0; i < blocks.size(); i++) { + if(blocks[i].first == ORDINARY_BLOCK_KIND) { + ResidualBlockDesc* desc = (ResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == GLOBAL_POOLING_BLOCK_KIND) { + GlobalPoolingResidualBlockDesc* desc = (GlobalPoolingResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == NESTED_BOTTLENECK_BLOCK_KIND) { + NestedBottleneckResidualBlockDesc* desc = (NestedBottleneckResidualBlockDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_ATTENTION_BLOCK_KIND) { + TransformerAttentionDesc* desc = (TransformerAttentionDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else if(blocks[i].first == TRANSFORMER_FFN_BLOCK_KIND) { + TransformerFFNDesc* desc = (TransformerFFNDesc*)blocks[i].second.get(); + desc->releaseWeights(); + } + else { + ASSERT_UNREACHABLE; + } + } + // Whichever trunk tip norm is unused has empty parameter vectors, so releasing both is safe. + trunkTipBN.releaseWeights(); + trunkTipRMSNorm.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -2086,6 +2223,18 @@ void PolicyHeadDesc::applyScale8ToReduceActivations() { passActivation.applyScale8ToReduceActivations(); } +void PolicyHeadDesc::releaseWeights() { + p1Conv.releaseWeights(); + g1Conv.releaseWeights(); + g1BN.releaseWeights(); + gpoolToBiasMul.releaseWeights(); + p1BN.releaseWeights(); + p2Conv.releaseWeights(); + gpoolToPassMul.releaseWeights(); + gpoolToPassBias.releaseWeights(); + gpoolToPassMul2.releaseWeights(); +} + //----------------------------------------------------------------------------- ValueHeadDesc::ValueHeadDesc() : modelVersion(-1) {} @@ -2246,6 +2395,17 @@ void ValueHeadDesc::applyScale8ToReduceActivations() { sv3Bias.applyScale8ToReduceActivations(); } +void ValueHeadDesc::releaseWeights() { + v1Conv.releaseWeights(); + v1BN.releaseWeights(); + v2Mul.releaseWeights(); + v2Bias.releaseWeights(); + v3Mul.releaseWeights(); + v3Bias.releaseWeights(); + sv3Mul.releaseWeights(); + sv3Bias.releaseWeights(); + vOwnershipConv.releaseWeights(); +} //----------------------------------------------------------------------------- @@ -2562,6 +2722,12 @@ void ModelDesc::applyScale8ToReduceActivations() { postProcessParams.outputScaleMultiplier *= 8.0f; } +void ModelDesc::releaseWeights() { + trunk.releaseWeights(); + policyHead.releaseWeights(); + valueHead.releaseWeights(); +} + struct NonCopyingStreamBuf : public std::streambuf { NonCopyingStreamBuf(string& str) { diff --git a/cpp/neuralnet/desc.h b/cpp/neuralnet/desc.h index 36c5a11d8..ef41dfca6 100644 --- a/cpp/neuralnet/desc.h +++ b/cpp/neuralnet/desc.h @@ -36,6 +36,8 @@ struct ConvLayerDesc { int64_t getNumParameters() const; void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct BatchNormLayerDesc { @@ -68,6 +70,8 @@ struct BatchNormLayerDesc { void extractChannelFactorsAbsLtOne(std::vector& channelFactors); void extractChannelFactorsAbsLtOneWithInverses(std::vector& channelFactors, std::vector& invChannelFactors); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ActivationLayerDesc { @@ -105,6 +109,8 @@ struct MatMulLayerDesc { int64_t getNumParameters() const; void scaleOutputChannels(const std::vector& scaling); + + void releaseWeights(); }; struct MatBiasLayerDesc { @@ -124,6 +130,8 @@ struct MatBiasLayerDesc { int64_t getNumParameters() const; void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ResidualBlockDesc { @@ -150,6 +158,8 @@ struct ResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct GlobalPoolingResidualBlockDesc { @@ -181,6 +191,8 @@ struct GlobalPoolingResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct NestedBottleneckResidualBlockDesc { @@ -215,6 +227,8 @@ struct NestedBottleneckResidualBlockDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; // Trunk final normalization kind (stored in trunk header) @@ -240,6 +254,7 @@ struct RMSNormLayerDesc { RMSNormLayerDesc& operator=(RMSNormLayerDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; // Lightweight RMSNorm used inside transformer blocks (weight only, no bias, no spatial modes) @@ -259,6 +274,7 @@ struct TransformerRMSNormDesc { TransformerRMSNormDesc& operator=(TransformerRMSNormDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; struct TransformerAttentionDesc { @@ -294,6 +310,7 @@ struct TransformerAttentionDesc { TransformerAttentionDesc& operator=(TransformerAttentionDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); // Compute cos/sin tables for RoPE given board dimensions. // Output tables are indexed as: @@ -324,6 +341,7 @@ struct TransformerFFNDesc { TransformerFFNDesc& operator=(TransformerFFNDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; struct SGFMetadataEncoderDesc { @@ -349,6 +367,7 @@ struct SGFMetadataEncoderDesc { SGFMetadataEncoderDesc& operator=(SGFMetadataEncoderDesc&& other); int64_t getNumParameters() const; + void releaseWeights(); }; @@ -397,6 +416,8 @@ struct TrunkDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct PolicyHeadDesc { @@ -431,6 +452,8 @@ struct PolicyHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ValueHeadDesc { @@ -463,6 +486,8 @@ struct ValueHeadDesc { void transformToReduceActivations(); void applyScale8ToReduceActivations(); + + void releaseWeights(); }; struct ModelPostProcessParams { @@ -534,6 +559,11 @@ struct ModelDesc { //Fills supported with true if desiredRules itself was exactly supported, false if some modifications had to be made. Rules getSupportedRules(const Rules& desiredRules, bool& supported) const; + // Frees all weight arrays (conv/matmul/bias/batchnorm), keeping scalar shape + // metadata intact. Safe once weights are no longer needed (e.g. CoreML/ANE + // inference, which reads weights from the compiled .mlmodelc). + void releaseWeights(); + }; #endif // #ifndef DESC_H diff --git a/cpp/neuralnet/greedysearch.h b/cpp/neuralnet/greedysearch.h new file mode 100644 index 000000000..65bf59825 --- /dev/null +++ b/cpp/neuralnet/greedysearch.h @@ -0,0 +1,75 @@ +#ifndef NEURALNET_GREEDYSEARCH_H_ +#define NEURALNET_GREEDYSEARCH_H_ + +// Pure, header-only greedy coordinate descent over discrete axes. No MLX/Metal +// dependency, so it is unit-tested standalone. Axes are index-based: each axis +// has a fixed number of candidate value-indices [0, size); the caller maps an +// index assignment to a concrete config inside its score callback. Lower score +// is better; the callback returns +inf for an invalid assignment. + +#include +#include +#include +#include + +namespace GreedySearch { + +struct Result { + std::vector indices; // best value-index per axis + double score; // its score + int evaluated; // number of scoreFn calls (instrumentation/tests) +}; + +// axisSizes[a] = number of candidate values for axis a. +// order = axis indices, highest-sensitivity first (a permutation of [0,nAxes)). +// seedIndices = starting index per axis; MUST score finite (it is the always-valid floor). +// scoreFn(idx) = lower is better; return +inf for invalid assignments. +// maxPasses = pass cap; descent also stops early on a no-change pass. +inline Result coordinateDescent( + const std::vector& axisSizes, + const std::vector& order, + const std::vector& seedIndices, + const std::function&)>& scoreFn, + int maxPasses) { + const size_t nAxes = axisSizes.size(); + assert(seedIndices.size() == nAxes); + assert(order.size() == nAxes); + +#ifndef NDEBUG + { + std::vector seen(nAxes, 0); + for(int a : order) { + assert(a >= 0 && (size_t)a < nAxes); + assert(!seen[a]); + seen[a] = 1; + } + } +#endif + + std::vector best = seedIndices; + double bestScore = scoreFn(best); + int evaluated = 1; + + for(int pass = 0; pass < maxPasses; pass++) { + bool changed = false; + for(int axis : order) { + const int curVal = best[axis]; + int bestVal = curVal; + for(int v = 0; v < axisSizes[axis]; v++) { + if(v == curVal) continue; // current value's score is already bestScore + std::vector trial = best; + trial[axis] = v; + const double s = scoreFn(trial); + evaluated++; + if(s < bestScore) { bestScore = s; bestVal = v; } + } + if(bestVal != curVal) { best[axis] = bestVal; changed = true; } + } + if(!changed) break; + } + return Result{best, bestScore, evaluated}; +} + +} // namespace GreedySearch + +#endif // NEURALNET_GREEDYSEARCH_H_ diff --git a/cpp/neuralnet/greedysearch_test.cpp b/cpp/neuralnet/greedysearch_test.cpp new file mode 100644 index 000000000..f1de21838 --- /dev/null +++ b/cpp/neuralnet/greedysearch_test.cpp @@ -0,0 +1,73 @@ +// Standalone unit test for the pure greedy coordinate-descent core. +// Build & run (no Xcode needed): +// clang++ -std=c++20 -I cpp cpp/neuralnet/greedysearch_test.cpp -o /tmp/greedysearch_test && /tmp/greedysearch_test +#include "neuralnet/greedysearch.h" +#include +#include +#include +#include + +using std::vector; + +static int failures = 0; +#define CHECK(cond) do { if(!(cond)) { std::printf("FAIL %s:%d %s\n", __FILE__, __LINE__, #cond); failures++; } } while(0) + +int main() { + // Axes: 3 axes of sizes 4,4,3. Separable score with a planted optimum at + // indices (3,0,2): score = |i0-3| + |i1-0| + |i2-2|. Coordinate descent on a + // separable convex score must reach the exact optimum (score 0). + { + vector sizes = {4,4,3}; + vector order = {0,1,2}; + vector seed = {0,0,0}; + int target0=3, target1=0, target2=2; + auto score = [&](const vector& idx)->double { + return std::abs(idx[0]-target0) + std::abs(idx[1]-target1) + std::abs(idx[2]-target2); + }; + GreedySearch::Result r = GreedySearch::coordinateDescent(sizes, order, seed, score, 3); + CHECK(r.indices == (vector{3,0,2})); + CHECK(r.score == 0.0); + CHECK(r.evaluated >= 1); + } + + // Invalid combos (score +inf) are never selected and never crash. + { + vector sizes = {3,3}; + vector order = {0,1}; + vector seed = {0,0}; + auto score = [&](const vector& idx)->double { + if(idx[0]==2 && idx[1]==2) return std::numeric_limits::infinity(); // forbidden + return (idx[0]==2 ? 0.0 : 1.0) + (idx[1]==2 ? 0.0 : 1.0); // wants (2,2) but it's invalid + }; + GreedySearch::Result r = GreedySearch::coordinateDescent(sizes, order, seed, score, 3); + CHECK(!(r.indices[0]==2 && r.indices[1]==2)); + CHECK(std::isfinite(r.score)); + } + + // Deterministic: identical inputs → identical result. + { + vector sizes = {4,3,2}; + vector order = {2,0,1}; + vector seed = {1,1,0}; + auto score = [&](const vector& idx)->double { return (idx[0]-2)*(idx[0]-2) + idx[1] + (1-idx[2]); }; + GreedySearch::Result a = GreedySearch::coordinateDescent(sizes, order, seed, score, 3); + GreedySearch::Result b = GreedySearch::coordinateDescent(sizes, order, seed, score, 3); + CHECK(a.indices == b.indices); + CHECK(a.score == b.score); + } + + // Constant score → no axis improves → returns the seed; evaluations bounded. + { + vector sizes = {4,4,3}; + vector order = {0,1,2}; + vector seed = {2,3,1}; + auto score = [&](const vector&)->double { return 7.0; }; + GreedySearch::Result r = GreedySearch::coordinateDescent(sizes, order, seed, score, 3); + CHECK(r.indices == seed); + // 1 seed eval + one pass of (sizes-1) probes, then a no-change pass stops it. + CHECK(r.evaluated <= 1 + 3*((4-1)+(4-1)+(3-1))); + } + + if(failures==0) std::printf("ALL GREEDY TESTS PASSED\n"); + return failures==0 ? 0 : 1; +} diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 786ef8290..6471e74c3 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -426,13 +426,27 @@ ComputeContext* NeuralNet::createComputeContext( const LoadedModel* loadedModel, ConfigParser& cfg) { - (void)gpuIdxs; + // Only ANE-only configurations may free the engine's in-memory weights: the + // GPU/MPSGraph path reads them via modelDescToSwift, so freeing is unsafe + // unless no GPU handle can ever be built from this model. + // INVARIANT: gpuIdxs must be the complete (deduplicated) set of device indices + // that will ever be passed as gpuIdxForThisThread to createComputeHandle for + // this context. aneOnly==true frees the in-memory weights, so if any thread + // later used a GPU (MPSGraph) index not represented here, it would read freed + // weights. KataGo derives both from the same gpuIdxByServerThread list, so the + // invariant holds today; preserve it if that wiring ever changes. + bool aneOnly = !gpuIdxs.empty(); + for(int idx : gpuIdxs) { + if(idx != METAL_MUX_ANE) { aneOnly = false; break; } + } (void)logger; (void)homeDataDirOverride; (void)loadedModel; (void)cfg; - return new ComputeContext(nnXLen, nnYLen, useFP16Mode); + ComputeContext* context = new ComputeContext(nnXLen, nnYLen, useFP16Mode); + context->aneOnly = aneOnly; + return context; } void NeuralNet::freeComputeContext(ComputeContext* computeContext) { @@ -459,6 +473,17 @@ static swift::Optional convertAndCreateCoreMLO bool useFP16 = (context->useFP16Mode != enabled_t::False); bool optimizeMask = requireExactNNLen; + // On a confirmed ANE-only run, free the engine's in-memory ModelDesc weight + // arrays. This function converts from loadedModel->modelPath (disk), + // so the in-memory weights are not read here; the GPU/MPSGraph path (which + // DOES read them via modelDescToSwift) is never built when aneOnly is true. + // The whole ComputeHandle ctor runs under computeHandleMutex, so this is not + // racy; releaseWeights() clears only weight vectors, leaving the scalar dims + // read by the ComputeHandle ctor / InputBuffers valid. + if(context->aneOnly) { + const_cast(loadedModel)->modelDesc.releaseWeights(); + } + // Convert model to CoreML format in temp directory string coremlModelPath = CoreMLConversion::convertModelToTemp( loadedModel->modelPath, diff --git a/cpp/neuralnet/metalbackend.h b/cpp/neuralnet/metalbackend.h index b7f751e63..e77dd18d9 100644 --- a/cpp/neuralnet/metalbackend.h +++ b/cpp/neuralnet/metalbackend.h @@ -113,6 +113,14 @@ struct ComputeContext { */ MetalComputeContext metalContext; + /** + * @brief True only when every configured device is METAL_MUX_ANE, so no + * MPSGraph (GPU) handle will ever read modelDesc weights. Gates the call to + * ModelDesc::releaseWeights() so a mixed GPU+ANE config can never free live + * weights. + */ + bool aneOnly = false; + /** * @brief Constructs a ComputeContext object. * @param nnX The width of the input tensor. @@ -179,6 +187,18 @@ struct ComputeHandle { */ bool maskIdentityChecked = false; + // Weight-release safety is guaranteed by ComputeContext::aneOnly, NOT by the + // declaration order below: within a single ComputeHandle exactly one handle is + // built (the two paths are mutually exclusive on gpuIdx, enforced by the + // ctor's exactly-one check), and releaseWeights() only ever fires on an + // aneOnly context, where no MPSGraph handle is built for any thread. + // That said, keep mpsGraphOnlyHandle declared before coremlOnlyHandle. C++ + // initializes members in DECLARATION order, so createMPSGraphHandleIfNeeded + // (which reads modelDesc weights via modelDescToSwift) is sequenced before + // createCoreMLOnlyHandleIfNeeded (which may call modelDesc.releaseWeights()). + // This ordering is belt-and-suspenders that preserves the natural read-then- + // release sequence should the aneOnly invariant ever be weakened; don't rely + // on it as the primary guarantee, but don't reorder it either. /** * @brief The MPSGraph-only handle instance from Swift (GPU-only mode). */ diff --git a/cpp/neuralnet/mlxbackend.cpp b/cpp/neuralnet/mlxbackend.cpp new file mode 100644 index 000000000..4092a2ac5 --- /dev/null +++ b/cpp/neuralnet/mlxbackend.cpp @@ -0,0 +1,3166 @@ +#ifdef USE_MLX_BACKEND + +/** + * MLX backend for KataGo. + * Uses Apple's MLX framework for neural network inference on Apple Silicon. + * Supports FP16 (half precision) and FP32 computation with NHWC memory layout. + * FP16 Winograd uses selective fp32 accumulation at the matmul reduction and + * BatchNorm intermediate for numerical stability. + * `mlxUseFP16 = auto` resolves to fp16. + */ + +#include "../neuralnet/nninterface.h" +#include "../neuralnet/desc.h" +#include "../neuralnet/modelversion.h" +#include "../neuralnet/nninputs.h" +#include "../neuralnet/nneval.h" +#include "../neuralnet/activations.h" +#include "../neuralnet/mlxwinograd.h" +#include "../neuralnet/mlxwinotuner.h" +#include "../core/global.h" +#include "../core/test.h" + +#include +#include +#include +#include +#include +#include // For getpid() +#include +#include +#include // malloc / std::getenv +#include +#include +#include +#include +#include +#include +#include +#include + +// Test-only free functions, defined in mlxtests.cpp. Invoked once per +// process from testEvaluateConv via the ranMLXAuxTests guard. +void runMLXWinogradTests(); +void runMLXWinotunerTests(); + +namespace mx = mlx::core; + +// Type alias for compiled inference functions +using CompiledInferenceFunc = std::function(const std::vector&)>; + +// Cache key: (batchSize, nnXLen, nnYLen, useMask, hasMeta, useFP16) +using CompileCacheKey = std::tuple; +using namespace std; + +// MUX modes: gpuIdx selects per-thread execution path. +// Same convention the Metal backend uses (METAL_MUX_GPU / METAL_MUX_ANE). +static constexpr int MLX_MUX_GPU = 0; // MLX/GPU - default +static constexpr int MLX_MUX_ANE = 100; // CoreML on CPU+ANE via katagocoreml + KataGoSwift + +// Serializes ComputeHandle construction across server threads. The CoreML +// converter (katagocoreml::KataGoConverter::convert) holds process-global +// MIL writer state that is not reentrant; without this lock, 2+ ANE threads +// racing at startup corrupt the .mlpackage and throw "Metadata written to +// different offset than expected." Mirrors metalbackend.cpp's +// computeHandleMutex. +static std::mutex computeHandleMutex; + +// Serializes MLX/GPU graph construction and command encoding across NN server +// threads. MLX's GPU streams have no per-stream worker thread (scheduler.h +// pushes nullptr for gpu streams); encoding runs inline on the calling thread +// and every handle shares MLX's single global default GPU stream. Two server +// threads encoding concurrently therefore open two compute command encoders on +// the same MTLCommandBuffer, which aborts with the Metal assertion "A command +// encoder is already encoding to this command buffer". So applyCompiled holds +// this lock across input-array creation, the compiled-graph call, and +// mx::async_eval (which performs all Metal command encoding synchronously on +// the calling thread before returning). The wait for GPU completion +// (array::wait, a per-array shared-event block) and the result readback +// (data() on already-materialized buffers) do not encode, so they run +// OUTSIDE the lock: with numNNServerThreadsPerModel=2, one server thread +// encodes its batch while the other waits on the GPU, keeping the GPU fed +// back-to-back instead of idling during encode + readback. +static std::mutex mlxGpuEvalMutex; + +//------------------------------------------------------------------------------ +// CoreML Model Conversion - reuses katagocoreml library, mirrors metalbackend.cpp +//------------------------------------------------------------------------------ + +namespace gfs = ghc::filesystem; + +namespace CoreMLConversion { + +// Get temp directory for model conversion. Identical path to Metal's +// getTempDirectory() in metalbackend.cpp so a .mlpackage produced by either +// backend can be reused by the other on a same-model run. +static string getTempDirectory() { + gfs::path tempDir = gfs::temp_directory_path() / "katago_coreml"; + std::error_code ec; + gfs::create_directories(tempDir, ec); + if(ec) { + throw runtime_error("Failed to create temp directory: " + ec.message()); + } + return tempDir.string(); +} + +// Generate unique temporary path for model conversion +static string generateTempPath(int serverThreadIdx) { + auto now = chrono::steady_clock::now().time_since_epoch().count(); + return getTempDirectory() + "/model_" + to_string(getpid()) + "_" + + to_string(serverThreadIdx) + "_" + to_string(now) + ".mlpackage"; +} + +// CoreML model metadata constants +static const string COREML_MODEL_AUTHOR = "KataGo"; +static const string COREML_MODEL_LICENSE = "See original model file for license terms"; + +// Convert KataGo model to CoreML in temp directory, returns path to .mlpackage. +// The caller (Swift side) is responsible for deleting the temp file after loading: +// see deleteSourceModel in metalbackend.swift, invoked via `defer` from +// createCoreMLComputeHandle. +static string convertModelToTemp( + const string& modelPath, + int boardX, + int boardY, + bool useFP16, + bool optimizeMask, + int maxBatchSize, + int serverThreadIdx +) { + // maxBatchSize is validated upstream: cfg.getInt("nnMaxBatchSize", 1, 65536) in setup.cpp + // and NNEvaluator constructor throws if maxBatchSize <= 0. Assert for defensive documentation. + assert(maxBatchSize >= 1); + + string tempPath = generateTempPath(serverThreadIdx); + cerr << "MLX backend " << serverThreadIdx << ": Converting model to " << tempPath << endl; + + katagocoreml::ConversionOptions opts; + opts.board_x_size = boardX; + opts.board_y_size = boardY; + opts.compute_precision = useFP16 ? "FLOAT16" : "FLOAT32"; + opts.optimize_identity_mask = optimizeMask; + opts.min_batch_size = 1; + opts.max_batch_size = maxBatchSize; + opts.author = COREML_MODEL_AUTHOR; + opts.license = COREML_MODEL_LICENSE; + + try { + katagocoreml::KataGoConverter::convert(modelPath, tempPath, opts); + } catch(const exception& e) { + // Clean up partial conversion on failure + std::error_code ec; + gfs::remove_all(tempPath, ec); + if(ec) { + cerr << "MLX backend " << serverThreadIdx << ": Warning: Failed to clean up partial conversion at " << tempPath << ": " << ec.message() << endl; + } + throw runtime_error(string("MLX backend ") + to_string(serverThreadIdx) + ": Core ML model conversion failed: " + e.what()); + } + + cerr << "MLX backend " << serverThreadIdx << ": Conversion completed" << endl; + return tempPath; +} + +} // namespace CoreMLConversion + +// LoadedModel / ModelDesc --------------------------------------------------------------------------------------------- + +struct LoadedModel { + ModelDesc modelDesc; + // Source path of the .bin.gz, retained for CoreML/ANE mux: the katagocoreml + // converter needs the on-disk source to produce a .mlpackage. The MLX GPU + // path does not read this field. + string modelPath; + + LoadedModel(const string& fileName, const string& expectedSha256) { + ModelDesc::loadFromFileMaybeGZipped(fileName, modelDesc, expectedSha256); + modelPath = fileName; + } + + LoadedModel() = delete; + LoadedModel(const LoadedModel&) = delete; + LoadedModel& operator=(const LoadedModel&) = delete; +}; + +LoadedModel* NeuralNet::loadModelFile(const string& file, const string& expectedSha256) { + LoadedModel* loadedModel = new LoadedModel(file, expectedSha256); + return loadedModel; +} + +void NeuralNet::freeLoadedModel(LoadedModel* loadedModel) { + delete loadedModel; +} + +const ModelDesc& NeuralNet::getModelDesc(const LoadedModel* loadedModel) { + return loadedModel->modelDesc; +} + +// Helpers -------------------------------------------------------------------------------------------------------------- + +// Convert convolution weights from OIHW to OHWI (MLX conv2d weight format) +static mx::array convertConvWeightsOIHWtoOHWI(const vector& weights, + int outChannels, int inChannels, + int kH, int kW) { + // Original: [outC, inC, kH, kW] - stored in column-major order + // Target: [outC, kH, kW, inC] + vector converted(weights.size()); + for (int oc = 0; oc < outChannels; oc++) { + for (int ic = 0; ic < inChannels; ic++) { + for (int h = 0; h < kH; h++) { + for (int w = 0; w < kW; w++) { + int srcIdx = ((oc * inChannels + ic) * kH + h) * kW + w; + int dstIdx = ((oc * kH + h) * kW + w) * inChannels + ic; + converted[dstIdx] = weights[srcIdx]; + } + } + } + } + mx::Shape shape = {outChannels, kH, kW, inChannels}; + return mx::array(converted.data(), shape, mx::float32); +} + +// Convert array to compute dtype. Lazy form for the inference hot path +// (each call's astype goes into the compiled trace; evaluating eagerly +// would force a stream sync per inference). +static mx::array toComputeDtype(const mx::array& arr, bool useFP16) { + return useFP16 ? mx::astype(arr, mx::float16) : arr; +} + +// Convert array to compute dtype and materialize the result. +// +// Use this for STATIC layer weights cached on a shared Model (the +// `cachedModels` map below shares a single Model instance across all +// MLX/GPU server threads). Without the eval, fp16 weights are +// unevaluated AsType primitives stamped with the constructor thread's +// MLX Stream; any other thread that later evals a compiled graph that +// captures these weights throws "There is no Stream(gpu, N) in current +// thread." with N = the constructor thread's stream index. MLX +// 0.31.2's command encoders live in `thread_local` storage inside +// mlx-core's metal/device.cpp, so a stream created on thread A is +// unreachable from thread B. +static mx::array toComputeDtypeMaterialized(const mx::array& arr, bool useFP16) { + if(!useFP16) return arr; + mx::array result = mx::astype(arr, mx::float16); + mx::eval(result); + return result; +} + +// Mish activation: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x))) +// +// Numerical stability: softplus is computed via logaddexp(0, x), which MLX +// implements as max(0, x) + log1p(exp(-|x|)) (see mlx/backend/cpu/binary_ops.h +// LogAddExp). The exp argument is always in (-inf, 0], so exp(-|x|) lies in +// (0, 1] and cannot overflow in either FP32 or FP16. This is why MLX does +// not need the ACTIVATION_MISH_SCALE8 variant that CUDA/OpenCL/TensorRT apply +// at model load (each backend calls modelDesc.applyScale8ToReduceActivations, +// implemented in desc.cpp) to keep Mish inside FP16 +// representable range: those backends compute softplus via a path that +// overflows for x >~ 11 in FP16 (since exp(11.09) >~ 65504 = FP16 max). +// Cross-backend validation against an Eigen FP32 reference confirms FP16 +// MLX is within typical half-precision tolerance with no Mish-overflow +// artifacts (see testgpuerror workflow in CLAUDE.md). +static mx::array applyMish(const mx::array& x) { + // softplus(x) = log(1 + exp(x)) = log(exp(0) + exp(x)) = logaddexp(0, x). + // MLX's logaddexp uses max(0,x) + log1p(exp(-|x|)) -- overflow-free. + mx::array softplus = mx::logaddexp(mx::array(0.0f), x); + return x * mx::tanh(softplus); +} + +// Apply activation function +static mx::array applyActivation(const mx::array& x, int activationType) { + switch(activationType) { + case ACTIVATION_RELU: + return mx::maximum(x, mx::array(0.0f)); + case ACTIVATION_MISH: + return applyMish(x); + case ACTIVATION_SILU: + // SiLU (swish): x * sigmoid(x) = x / (1 + exp(-x)). Matches Eigen's + // ACTIVATION_SILU (x / ((-x).exp() + 1)). MLX's sigmoid is overflow-safe. + return x * mx::sigmoid(x); + case ACTIVATION_MISH_SCALE8: + // ACTIVATION_MISH_SCALE8 is an FP16-numerics workaround applied in-place + // at model load by CUDA/OpenCL/TensorRT (see desc.cpp:applyScale8To- + // ReduceActivations). MLX does not call that transform because its + // logaddexp-based softplus is already overflow-free in FP16 (see + // applyMish above), so we should never see this enum here. If a model + // ever ships with MISH_SCALE8 baked in on disk, fail loudly rather than + // silently fall through to identity. Mirrors Eigen/Metal behavior. + testAssert(false); + return x; // unreached; satisfies compiler + case ACTIVATION_IDENTITY: + default: + return x; + } +} + +// Fused matmul + bias: result = input @ weights + bias +// Uses addmm for better performance (single kernel instead of matmul + add) +static mx::array matmulBias(const mx::array& input, const mx::array& weights, const mx::array& bias) { + // addmm(c, a, b, alpha, beta) = alpha * (a @ b) + beta * c + return mx::addmm(bias, input, weights, 1.0f, 1.0f); +} + +// Winograd is on by default; KATAGO_MLX_WINOGRAD=0 forces mx::conv2d +// (A/B correctness testing and runtime safety valve). +static bool mlxWinogradEnabled() { + static const bool enabled = [](){ + const char* e = std::getenv("KATAGO_MLX_WINOGRAD"); + return !(e != nullptr && std::string(e) == "0"); + }(); + return enabled; +} + +// Tuner is on by default; KATAGO_MLX_WINOTUNER=0 forces baked defaults. +static bool mlxWinotunerEnabled() { + static const bool enabled = [](){ + const char* e = std::getenv("KATAGO_MLX_WINOTUNER"); + return !(e != nullptr && std::string(e) == "0"); + }(); + return enabled; +} +// Layers -------------------------------------------------------------------------------------------------------------- + +struct ConvLayer { + const string name; + const int convYSize; + const int convXSize; + const int inChannels; + const int outChannels; + const int dilationY; + const int dilationX; + const bool useFP16; + const bool useWinograd; + mx::array weights; // OHWI format (only built when !useWinograd) + mx::array winogradWeights; // 4x4 domain U, valid only if useWinograd + const MLXWinograd::InputTransform winoInCfg; + const MLXWinograd::OutputUntransform winoOutCfg; + + ConvLayer() = delete; + ConvLayer(const ConvLayer&) = delete; + ConvLayer& operator=(const ConvLayer&) = delete; + + ConvLayer(const ConvLayerDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16_ = false) + : name(desc.name), + convYSize(desc.convYSize), + convXSize(desc.convXSize), + inChannels(desc.inChannels), + outChannels(desc.outChannels), + dilationY(desc.dilationY), + dilationX(desc.dilationX), + useFP16(useFP16_), + // Winograd path runs in fp16 too (no `!useFP16` gate). + useWinograd(mlxWinogradEnabled() + && convYSize==3 && convXSize==3 + && dilationY==1 && dilationX==1), + weights(useWinograd ? mx::array(0.0f) : toComputeDtypeMaterialized(convertConvWeightsOIHWtoOHWI(desc.weights, outChannels, inChannels, convYSize, convXSize), useFP16_)), + winogradWeights(useWinograd + ? MLXWinograd::makeWinogradWeights(desc.weights, outChannels, inChannels, useFP16_) + : mx::array(0.0f)) + ,winoInCfg(inCfg) + ,winoOutCfg(outCfg) + {} + + mx::array apply(const mx::array& input) const { + if(useWinograd) { + return MLXWinograd::winogradConv2d(input, winogradWeights, outChannels, winoInCfg, winoOutCfg, useFP16); + } + // MLX conv2d: input NHWC, weights OHWI + // Compute padding to maintain spatial dimensions (same padding) + int padY = (convYSize - 1) * dilationY / 2; + int padX = (convXSize - 1) * dilationX / 2; + + return mx::conv2d( + input, + weights, + /*stride=*/std::make_pair(1, 1), + /*padding=*/std::make_pair(padY, padX), + /*dilation=*/std::make_pair(dilationY, dilationX), + /*groups=*/1 + ); + } +}; + +struct BatchNormLayer { + const string name; + const int numChannels; + const int activation; + const bool useFP16; + mx::array mergedScale; // Shape: [C], always fp32 + mx::array mergedBias; // Shape: [C], always fp32 + + BatchNormLayer() = delete; + BatchNormLayer(const BatchNormLayer&) = delete; + BatchNormLayer& operator=(const BatchNormLayer&) = delete; + + // mergedScale/mergedBias storage is always fp32 to preserve dynamic + // range across the 25-block-deep b18c384 chain. The `useFP16` parameter + // is intentionally ignored. + static mx::array createArray1D(const std::vector& data, int size, bool /*useFP16*/) { + mx::Shape shape = {size}; + return mx::array(data.data(), shape, mx::float32); + } + + static std::vector getMergedScale(const BatchNormLayerDesc& desc) { + // If mergedScale is already computed, use it + if(!desc.mergedScale.empty()) { + return desc.mergedScale; + } + // Otherwise compute from mean/variance/scale/bias (for tests) + std::vector mergedScale(desc.numChannels); + for(int c = 0; c < desc.numChannels; c++) { + mergedScale[c] = desc.scale[c] / sqrt(desc.variance[c] + desc.epsilon); + } + return mergedScale; + } + + static std::vector getMergedBias(const BatchNormLayerDesc& desc) { + // If mergedBias is already computed, use it + if(!desc.mergedBias.empty()) { + return desc.mergedBias; + } + // Otherwise compute from mean/variance/scale/bias (for tests) + std::vector mergedBias(desc.numChannels); + for(int c = 0; c < desc.numChannels; c++) { + float ms = desc.scale[c] / sqrt(desc.variance[c] + desc.epsilon); + mergedBias[c] = desc.bias[c] - ms * desc.mean[c]; + } + return mergedBias; + } + + BatchNormLayer(const BatchNormLayerDesc& desc, int activationType, bool useFP16_ = false) + : name(desc.name), + numChannels(desc.numChannels), + activation(activationType), + useFP16(useFP16_), + mergedScale(createArray1D(getMergedScale(desc), desc.numChannels, useFP16_)), + mergedBias(createArray1D(getMergedBias(desc), desc.numChannels, useFP16_)) + {} + + mx::array apply(const mx::array& input, const mx::array& mask, bool useMask) const { + // input: NHWC [N, H, W, C] in compute dtype (fp16 or fp32). + // mask: NHW1 [N, H, W, 1] in compute dtype. + // mergedScale/mergedBias are always fp32; MLX type promotion lifts the + // multiply-add-activation chain to fp32 automatically (selective fp32 + // accumulation — defense against inf/nan in deep stacks). + // Mask multiply runs while activated is still fp32 (safe because mask is + // binary 0/1, so fp32*fp16 and fp16*fp16 round to bit-equal results). + // The single trailing astype-to-fp16 covers both useMask branches. + mx::array normalized = input * mergedScale + mergedBias; + mx::array activated = applyActivation(normalized, activation); + if(useMask) + activated = activated * mask; + // Cast back to fp16 so downstream layers see the expected compute dtype. + if(useFP16) activated = mx::astype(activated, mx::float16); + return activated; + } +}; + +struct MatMulLayer { + const string name; + const int inChannels; + const int outChannels; + mx::array weights; // [inC, outC] + + MatMulLayer() = delete; + MatMulLayer(const MatMulLayer&) = delete; + MatMulLayer& operator=(const MatMulLayer&) = delete; + + static mx::array createWeights(const MatMulLayerDesc& desc, bool useFP16) { + if(desc.inChannels > 0 && desc.outChannels > 0) { + // Original weights: [inC, outC] (column-major) + mx::Shape shape = {desc.inChannels, desc.outChannels}; + mx::array arr = mx::array(desc.weights.data(), shape, mx::float32); + return toComputeDtypeMaterialized(arr, useFP16); + } + std::vector dummy = {0.0f}; + mx::Shape shape = {1}; + return mx::array(dummy.data(), shape, mx::float32); + } + + MatMulLayer(const MatMulLayerDesc& desc, bool useFP16 = false) + : name(desc.name), + inChannels(desc.inChannels), + outChannels(desc.outChannels), + weights(createWeights(desc, useFP16)) + {} + + mx::array apply(const mx::array& input) const { + // input: [N, inC] + // output: [N, outC] + return mx::matmul(input, weights); + } +}; + +struct MatBiasLayer { + const string name; + const int numChannels; + mx::array bias; + + MatBiasLayer() = delete; + MatBiasLayer(const MatBiasLayer&) = delete; + MatBiasLayer& operator=(const MatBiasLayer&) = delete; + + static mx::array createBias(const MatBiasLayerDesc& desc, bool useFP16) { + mx::Shape shape = {desc.numChannels}; + mx::array arr = mx::array(desc.weights.data(), shape, mx::float32); + return toComputeDtypeMaterialized(arr, useFP16); + } + + MatBiasLayer(const MatBiasLayerDesc& desc, bool useFP16 = false) + : name(desc.name), + numChannels(desc.numChannels), + bias(createBias(desc, useFP16)) + {} + + mx::array apply(const mx::array& input) const { + return input + bias; + } +}; + +// -------------------------------------------------------------------------------------------------------------- +// Transformer layers (RMSNorm / attention / FFN). +// +// MLX operates on NHWC arrays [N, H, W, C]; C is the fastest (last) axis. The +// Eigen ground-truth (eigenbackend.cpp) operates on (C, W, H, N) col-major and +// reshapes spatial dims to a sequence. The math below mirrors Eigen exactly +// but stays in MLX's native NHWC layout: the spatial dims H,W are adjacent and +// C is last, so [N,H,W,C] reshapes contiguously to [N, seq, C] with +// seq = H*W = nnYLen*nnXLen. The mask [N,H,W,1] reshapes to [N, seq, 1]. +// +// All RMSNorm gamma/beta/weight buffers stay fp32 (like BatchNormLayer) to +// preserve dynamic range; MLX type promotion lifts the normalize chain to fp32 +// and the trailing astype returns to compute dtype. +// -------------------------------------------------------------------------------------------------------------- + +// Lightweight RMSNorm used inside transformer blocks (weight only, no bias, no +// spatial mode, no activation). Mirrors Eigen TransformerRMSNormLayer +// (eigenbackend.cpp 866-918). +struct TransformerRMSNormLayer { + const string name; + const int numChannels; + const float epsilon; + const bool useFP16; + mx::array weight; // [C], always fp32 + + TransformerRMSNormLayer() = delete; + TransformerRMSNormLayer(const TransformerRMSNormLayer&) = delete; + TransformerRMSNormLayer& operator=(const TransformerRMSNormLayer&) = delete; + + static mx::array createWeight(const TransformerRMSNormDesc& desc) { + mx::Shape shape = {desc.numChannels}; + return mx::array(desc.weight.data(), shape, mx::float32); + } + + TransformerRMSNormLayer(const TransformerRMSNormDesc& desc, bool useFP16_ = false) + : name(desc.name), + numChannels(desc.numChannels), + epsilon(desc.epsilon), + useFP16(useFP16_), + weight(createWeight(desc)) + {} + + // input/output NHWC [N, H, W, C] in compute dtype. mask NHW1 [N, H, W, 1]. + // Per-position RMSNorm across channels: out = x * rsqrt(mean(x^2) + eps) * weight, + // then masked to zero on invalid positions (Eigen zeroes masked positions). + mx::array apply(const mx::array& input, const mx::array& mask, bool useMask) const { + // Variance reduction in fp32 (Eigen computes sumSq in fp32, eigenbackend.cpp + // 904-910). mx::mean over an fp16 operand would accumulate the per-channel + // sum of squares in fp16; promote to fp32 first so the reduction is + // overflow/precision-safe before rsqrt. + mx::array inputF32 = useFP16 ? mx::astype(input, mx::float32) : input; + std::vector chAxis = {3}; + mx::array meanSq = mx::mean(inputF32 * inputF32, chAxis, /*keepdims=*/true); // [N,H,W,1], fp32 + mx::array rms = mx::rsqrt(meanSq + mx::array(epsilon)); + mx::array normalized = input * rms * weight; // weight fp32 promotes chain to fp32 + if(useMask) + normalized = normalized * mask; + if(useFP16) normalized = mx::astype(normalized, mx::float16); + return normalized; + } +}; + +// Full-featured RMSNorm for the trunk tip: gamma/beta, optional spatial mode, +// optional activation. Mirrors Eigen RMSNormLayer (eigenbackend.cpp 922-1022). +struct TransformerTrunkRMSNormLayer { + const string name; + const int numChannels; + const float epsilon; + const bool spatial; + const int activation; + const bool useFP16; + mx::array gamma; // [C], always fp32 + mx::array beta; // [C], always fp32 + + TransformerTrunkRMSNormLayer() = delete; + TransformerTrunkRMSNormLayer(const TransformerTrunkRMSNormLayer&) = delete; + TransformerTrunkRMSNormLayer& operator=(const TransformerTrunkRMSNormLayer&) = delete; + + static mx::array createVec(const std::vector& data, int size) { + mx::Shape shape = {size}; + return mx::array(data.data(), shape, mx::float32); + } + + TransformerTrunkRMSNormLayer(const RMSNormLayerDesc& desc, int activation_, bool useFP16_ = false) + : name(desc.name), + numChannels(desc.numChannels), + epsilon(desc.epsilon), + spatial(desc.spatial), + activation(activation_), + useFP16(useFP16_), + gamma(createVec(desc.gamma, desc.numChannels)), + beta(createVec(desc.beta, desc.numChannels)) + {} + + // input/output NHWC [N, H, W, C]. mask NHW1 [N, H, W, 1]. maskSum N111. + mx::array apply(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) const { + // The variance reduction (sum/mean of squares) MUST run in fp32 even when + // the compute dtype is fp16. Eigen computes sumSq in fp32 (eigenbackend.cpp + // 968-1002). In fp16 the spatial branch sums x^2 over up to + // nnXLen*nnYLen*numChannels (e.g. 19*19*256 ≈ 92k) elements; an fp16 + // accumulator saturates at 65504 and the rsqrt then yields 0/NaN/Inf which + // poisons the whole net (observed ~50% winrate / 99% top-policy error on the + // SiLU + spatial-rmsnorm-tip transformer). Promoting input to fp32 for the + // reduction makes the reduction overflow-free; the result rms is fp32 and the + // downstream normalize chain stays fp32 (gamma/beta are fp32 too) until the + // single trailing astype back to fp16. + mx::array inputF32 = useFP16 ? mx::astype(input, mx::float32) : input; + mx::array rms = mx::array(0.0f); + if(!spatial) { + // Non-spatial: per-position RMS across channels. + std::vector chAxis = {3}; + mx::array meanSq = mx::mean(inputF32 * inputF32, chAxis, /*keepdims=*/true); // [N,H,W,1], fp32 + rms = mx::rsqrt(meanSq + mx::array(epsilon)); + } + else { + // Spatial: per-batch RMS across all valid spatial positions AND channels. + // sumSq over valid positions only (Eigen skips masked positions); + // totalElts = count * numChannels with count = maskSum (valid spatial count). + std::vector spatialChAxes = {1, 2, 3}; + mx::array x2 = inputF32 * inputF32; // fp32 reduction operand (overflow-safe) + if(useMask) + x2 = x2 * mask; // mask is [N,H,W,1], broadcasts over channels + mx::array sumSq = mx::sum(x2, spatialChAxes, /*keepdims=*/true); // [N,1,1,1], fp32 + // maskSum is [N,1,1,1] (valid spatial position count); when !useMask it is + // a constant nnXLen*nnYLen full array, equally valid. + mx::array totalElts = maskSum * mx::array((float)numChannels); + rms = mx::rsqrt(sumSq / totalElts + mx::array(epsilon)); // [N,1,1,1], fp32 + } + mx::array normalized = input * rms * gamma + beta; // gamma/beta fp32 promote chain + normalized = applyActivation(normalized, activation); + if(useMask) + normalized = normalized * mask; + if(useFP16) normalized = mx::astype(normalized, mx::float16); + return normalized; + } +}; + +// Transformer attention block: GQA + optional RoPE + masked scaled-dot-product +// attention. Mirrors Eigen TransformerAttentionBlock (eigenbackend.cpp 1307-1588). +struct TransformerAttentionBlock { + const string name; + const int numHeads; + const int numKVHeads; + const int qHeadDim; + const int vHeadDim; + const bool useRope; + const bool learnableRope; + const int inChannels; + const bool useFP16; + const int nnXLen; + const int nnYLen; + + const TransformerRMSNormLayer preLN; + const MatMulLayer qProj; + const MatMulLayer kProj; + const MatMulLayer vProj; + const MatMulLayer outProj; + + // Precomputed RoPE cos/sin tables, materialized as MLX arrays (always fp32). + // Learnable layout (after reshaping for our use): [numKVHeads, numPairs, seq]. + // Fixed layout: [numPairs, seq]. ropeNumPairs = qHeadDim/2. + int ropeNumPairs; + mx::array ropeCos; // valid iff useRope + mx::array ropeSin; // valid iff useRope + + TransformerAttentionBlock() = delete; + TransformerAttentionBlock(const TransformerAttentionBlock&) = delete; + TransformerAttentionBlock& operator=(const TransformerAttentionBlock&) = delete; + + static mx::array makeRopeTable(const std::vector& table, bool learnable, int numKVHeads, int numPairs, int seq) { + if(table.empty()) + return mx::array(0.0f); + if(learnable) { + mx::Shape shape = {numKVHeads, numPairs, seq}; + return mx::array(table.data(), shape, mx::float32); + } + mx::Shape shape = {numPairs, seq}; + return mx::array(table.data(), shape, mx::float32); + } + + TransformerAttentionBlock(const TransformerAttentionDesc& desc, int nnX, int nnY, bool useFP16_ = false) + : name(desc.name), + numHeads(desc.numHeads), + numKVHeads(desc.numKVHeads), + qHeadDim(desc.qHeadDim), + vHeadDim(desc.vHeadDim), + useRope(desc.useRope), + learnableRope(desc.learnableRope), + inChannels(desc.qProj.inChannels), + useFP16(useFP16_), + nnXLen(nnX), + nnYLen(nnY), + preLN(desc.preLN, useFP16_), + qProj(desc.qProj, useFP16_), + kProj(desc.kProj, useFP16_), + vProj(desc.vProj, useFP16_), + outProj(desc.outProj, useFP16_), + ropeNumPairs(0), + ropeCos(mx::array(0.0f)), + ropeSin(mx::array(0.0f)) + { + if(useRope) { + ropeNumPairs = qHeadDim / 2; + int seq = nnX * nnY; + int paddedNNXYLen = seq; + std::vector cosTable, sinTable; + desc.computeRopeCosSin(nnX, nnY, paddedNNXYLen, cosTable, sinTable); + ropeCos = makeRopeTable(cosTable, learnableRope, numKVHeads, ropeNumPairs, seq); + ropeSin = makeRopeTable(sinTable, learnableRope, numKVHeads, ropeNumPairs, seq); + } + } + + // Apply RoPE to a projection laid out as [N, seq, numBufHeads, headDim]. + // RoPE rotates interleaved channel pairs (2p, 2p+1) within each head. + // cos/sin tables map per (kvHead-of-this-head, pair, seq). Returns rotated + // array in the same shape. headDim == qHeadDim here. + mx::array applyRope(const mx::array& proj, int numBufHeads) const { + int seq = nnXLen * nnYLen; + int batchSize = proj.shape()[0]; + // Split even/odd channel pairs. proj: [N, seq, numBufHeads, qHeadDim]. + // Reshape to [N, seq, numBufHeads, ropeNumPairs, 2] then take [...,0]/[...,1]. + mx::Shape pairShape = {batchSize, seq, numBufHeads, ropeNumPairs, 2}; + mx::array pairs = mx::reshape(proj, pairShape); + mx::array x0 = mx::squeeze(mx::slice(pairs, {0,0,0,0,0}, {batchSize, seq, numBufHeads, ropeNumPairs, 1}), std::vector{4}); // [N,seq,H,pairs] + mx::array x1 = mx::squeeze(mx::slice(pairs, {0,0,0,0,1}, {batchSize, seq, numBufHeads, ropeNumPairs, 2}), std::vector{4}); // [N,seq,H,pairs] + + // Build cos/sin broadcastable to [N, seq, numBufHeads, ropeNumPairs]. + // Source tables: learnable [numKVHeads, pairs, seq]; fixed [pairs, seq]. + // Target per-head index kvh = h * numKVHeads / numBufHeads (Eigen mapping). + mx::array cosB = mx::array(0.0f); + mx::array sinB = mx::array(0.0f); + if(learnableRope) { + // Expand each KV head to the Q heads that map to it. For head h, + // kvh = h * numKVHeads / numBufHeads. With numBufHeads a multiple of + // numKVHeads, this is a contiguous block expansion: repeat each kv head + // (numBufHeads / numKVHeads) times along the head axis. + int groupSize = numBufHeads / numKVHeads; + // ropeCos: [numKVHeads, pairs, seq] -> [numBufHeads, pairs, seq] + mx::array cosHeads = mx::repeat(ropeCos, groupSize, /*axis=*/0); // [numBufHeads, pairs, seq] + mx::array sinHeads = mx::repeat(ropeSin, groupSize, /*axis=*/0); + // -> [seq, numBufHeads, pairs] then expand batch axis + cosHeads = mx::transpose(cosHeads, std::vector{2, 0, 1}); // [seq, numBufHeads, pairs] + sinHeads = mx::transpose(sinHeads, std::vector{2, 0, 1}); + cosB = mx::expand_dims(cosHeads, 0); // [1, seq, numBufHeads, pairs] + sinB = mx::expand_dims(sinHeads, 0); + } + else { + // ropeCos: [pairs, seq] -> [seq, pairs], broadcast over heads. + mx::array cosSP = mx::transpose(ropeCos, std::vector{1, 0}); // [seq, pairs] + mx::array sinSP = mx::transpose(ropeSin, std::vector{1, 0}); + cosB = mx::expand_dims(mx::expand_dims(cosSP, 1), 0); // [1, seq, 1, pairs] + sinB = mx::expand_dims(mx::expand_dims(sinSP, 1), 0); + } + if(useFP16) { + cosB = mx::astype(cosB, mx::float16); + sinB = mx::astype(sinB, mx::float16); + } + + mx::array r0 = x0 * cosB - x1 * sinB; // [N,seq,H,pairs] + mx::array r1 = x0 * sinB + x1 * cosB; + // Re-interleave: stack along new last axis -> [N,seq,H,pairs,2] -> [N,seq,H,qHeadDim]. + mx::array stacked = mx::stack(std::vector{r0, r1}, /*axis=*/4); + mx::Shape outShape = {batchSize, seq, numBufHeads, ropeNumPairs * 2}; + return mx::reshape(stacked, outShape); + } + + // input/output NHWC [N, H, W, C]. mask NHW1 [N, H, W, 1]. + mx::array apply(const mx::array& trunk, const mx::array& mask, bool useMask) const { + int batchSize = trunk.shape()[0]; + int seq = nnXLen * nnYLen; + int kvGroupSize = numHeads / numKVHeads; + float scale = 1.0f / sqrtf((float)qHeadDim); + + // Step 1: preLN RMSNorm (masks output to zero on invalid positions). + mx::array normed = preLN.apply(trunk, mask, useMask); // [N,H,W,C] + + // Flatten spatial dims to a sequence: [N, seq, C]. + mx::Shape seqShape = {batchSize, seq, inChannels}; + mx::array normedSeq = mx::reshape(normed, seqShape); + + // Step 2: Q/K/V projections. weights are [inC, outC]; x[N,seq,inC] @ W. + mx::array q = mx::matmul(normedSeq, qProj.weights); // [N, seq, numHeads*qHeadDim] + mx::array k = mx::matmul(normedSeq, kProj.weights); // [N, seq, numKVHeads*qHeadDim] + mx::array v = mx::matmul(normedSeq, vProj.weights); // [N, seq, numKVHeads*vHeadDim] + + // Reshape to per-head: [N, seq, numHeads, qHeadDim] etc. + q = mx::reshape(q, mx::Shape{batchSize, seq, numHeads, qHeadDim}); + k = mx::reshape(k, mx::Shape{batchSize, seq, numKVHeads, qHeadDim}); + v = mx::reshape(v, mx::Shape{batchSize, seq, numKVHeads, vHeadDim}); + + // Step 3: RoPE on Q and K. + if(useRope) { + q = applyRope(q, numHeads); + k = applyRope(k, numKVHeads); + } + + // Move head axis ahead of seq: [N, numHeads, seq, headDim]. + q = mx::transpose(q, std::vector{0, 2, 1, 3}); // [N, numHeads, seq, qHeadDim] + k = mx::transpose(k, std::vector{0, 2, 1, 3}); // [N, numKVHeads, seq, qHeadDim] + v = mx::transpose(v, std::vector{0, 2, 1, 3}); // [N, numKVHeads, seq, vHeadDim] + + // Expand KV heads to match Q heads (GQA): repeat each kv head kvGroupSize + // times so head h uses kv head h/kvGroupSize (Eigen kvh = h / kvGroupSize). + if(kvGroupSize > 1) { + k = mx::repeat(k, kvGroupSize, /*axis=*/1); // [N, numHeads, seq, qHeadDim] + v = mx::repeat(v, kvGroupSize, /*axis=*/1); // [N, numHeads, seq, vHeadDim] + } + + // Step 4: scores = scale * Q @ K^T -> [N, numHeads, seq(query), seq(key)]. + // Keep the whole attention chain that follows (scores, softmax, attn@V, + // out-proj, and the residual stream) in the compute dtype. MLX C++ has no + // weak scalars: mx::array(scale) is a strong float32 array, so + // `matmul(q,kT) * mx::array(scale)` would promote scores -- and everything + // downstream, since each transformer block adds its output into the trunk -- + // to fp32. Cast scale to the compute dtype so the fp16 path stays fp16 + // end-to-end (maximize-fp16 on the GPU path; fp16 attention accuracy is gated + // by testgpuerror). The pooling/BN/RMSNorm layers handle this same MLX + // promotion rule by casting their fp32 intermediates back with a trailing + // astype; attention does it here at the source instead. + mx::array kT = mx::transpose(k, std::vector{0, 1, 3, 2}); // [N, numHeads, qHeadDim, seq] + mx::array scaleArr = useFP16 ? mx::astype(mx::array(scale), mx::float16) : mx::array(scale); + mx::array scores = mx::matmul(q, kT) * scaleArr; + + // Masked softmax over the key axis (last). Keys with mask==0 get -inf so + // they contribute 0; fully-masked query rows are zeroed afterward to match + // Eigen (which zeroes masked query rows entirely). + if(useMask) { + // keyMask: [N, 1, 1, seq] broadcasting over heads and query positions. + // Use 1e4 (not 1e9): representable in fp16, and exp(-1e4) underflows to 0 + // cleanly, avoiding inf arithmetic. Scores are O(1)*scale, so 1e4 fully + // suppresses masked keys. The board is never fully masked, so no query row + // is all-masked -> softmax never sees an all -inf row (no NaN). + mx::array maskSeq = mx::reshape(mask, mx::Shape{batchSize, 1, 1, seq}); // [N,1,1,seq] + mx::array neg = (mx::array(1.0f) - maskSeq) * mx::array(1e4f); + if(useFP16) neg = mx::astype(neg, mx::float16); + scores = scores - neg; + } + mx::array weights = mx::softmax(scores, /*axis=*/3, /*precise=*/true); // [N, numHeads, seq, seq] + + // attnOut = weights @ V -> [N, numHeads, seq(query), vHeadDim]. + mx::array attn = mx::matmul(weights, v); + + if(useMask) { + // Zero out fully-masked query rows (Eigen zeroes masked queries). With a + // masked-key softmax a masked query still produces a normalized row, but + // its residual contribution is gated by the trunk residual mask below, so + // this extra gating is for parity; multiply by query mask. + mx::array qMask = mx::reshape(mask, mx::Shape{batchSize, 1, seq, 1}); // [N,1,seq,1] + attn = attn * qMask; + } + + // Merge heads back: [N, numHeads, seq, vHeadDim] -> [N, seq, numHeads*vHeadDim]. + attn = mx::transpose(attn, std::vector{0, 2, 1, 3}); // [N, seq, numHeads, vHeadDim] + attn = mx::reshape(attn, mx::Shape{batchSize, seq, numHeads * vHeadDim}); + + // Step 5: output projection -> [N, seq, outC] (outC == inChannels). + mx::array out = mx::matmul(attn, outProj.weights); // [N, seq, inChannels] + mx::array outSpatial = mx::reshape(out, mx::Shape{batchSize, nnYLen, nnXLen, inChannels}); + + // Step 6: residual + mask. (Eigen adds outProj * maskVal to trunk.) + if(useMask) + outSpatial = outSpatial * mask; + return trunk + outSpatial; + } +}; + +// Transformer FFN block: SwiGLU. Mirrors Eigen TransformerFFNBlock +// (eigenbackend.cpp 1592-1707). Non-SwiGLU is unsupported (Eigen throws too). +struct TransformerFFNBlock { + const string name; + const int numChannels; + const int ffnChannels; + const bool useSwiGLU; + const bool useFP16; + const int nnXLen; + const int nnYLen; + + const TransformerRMSNormLayer preLN; + const MatMulLayer linear1; + unique_ptr linearGate; + const MatMulLayer linear2; + + TransformerFFNBlock() = delete; + TransformerFFNBlock(const TransformerFFNBlock&) = delete; + TransformerFFNBlock& operator=(const TransformerFFNBlock&) = delete; + + TransformerFFNBlock(const TransformerFFNDesc& desc, int nnX, int nnY, bool useFP16_ = false) + : name(desc.name), + numChannels(desc.numChannels), + ffnChannels(desc.ffnChannels), + useSwiGLU(desc.useSwiGLU), + useFP16(useFP16_), + nnXLen(nnX), + nnYLen(nnY), + preLN(desc.preLN, useFP16_), + linear1(desc.linear1, useFP16_), + linear2(desc.linear2, useFP16_) + { + if(useSwiGLU) + linearGate = make_unique(desc.linearGate, useFP16_); + else + throw StringError("MLX backend: Non-SwiGLU transformer FFN is not supported"); + } + + // input/output NHWC [N, H, W, C]. mask NHW1 [N, H, W, 1]. + mx::array apply(const mx::array& trunk, const mx::array& mask, bool useMask) const { + int batchSize = trunk.shape()[0]; + int seq = nnXLen * nnYLen; + + // Step 1: preLN RMSNorm. + mx::array normed = preLN.apply(trunk, mask, useMask); // [N,H,W,C] + mx::array normedSeq = mx::reshape(normed, mx::Shape{batchSize, seq, numChannels}); + + // Step 2/3: SwiGLU = SiLU(linear1(x)) * linearGate(x). + mx::array a = mx::matmul(normedSeq, linear1.weights); // [N, seq, ffnChannels] + mx::array gate = mx::matmul(normedSeq, linearGate->weights); + mx::array swiglu = (a * mx::sigmoid(a)) * gate; // SiLU(a) * gate + + // Step 4: linear2 -> [N, seq, numChannels]. + mx::array out = mx::matmul(swiglu, linear2.weights); + mx::array outSpatial = mx::reshape(out, mx::Shape{batchSize, nnYLen, nnXLen, numChannels}); + + // Step 5: residual + mask. + if(useMask) + outSpatial = outSpatial * mask; + return trunk + outSpatial; + } +}; + +// Global pooling: computes [mean, mean * (sqrt(maskSum) - 14) * 0.1, max] concatenated along channel axis +static mx::array applyGlobalPooling(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) { + // input: NHWC [N, H, W, C] + // mask: NHW1 [N, H, W, 1] + // maskSum: N111 [N, 1, 1, 1] + + // Keep the whole pooling in the input's compute dtype (fp16 in fp16 mode) so + // the pooled output matches the trunk and the downstream gpool-bias matmul + // stays fp16. maskSum is the valid-position count (<=361 for 19x19), exact in + // fp16. The fp32-accumulation variant measured negligible accuracy gain + // (PR#1199 M1), so prefer fp16 consistency (minimize fp32 when useFP16). + const auto dt = input.dtype(); + std::vector spatialAxes = {1, 2}; + mx::array spatialSum = mx::sum(input, spatialAxes, /*keepdims=*/true); // [N, 1, 1, C], dt + mx::array maskSumDt = mx::astype(maskSum, dt); + + // Mean = sum / maskSum + mx::array mean = spatialSum / maskSumDt; // [N, 1, 1, C], dt + + // sqrt(maskSum) - 14) * 0.1 (scalar literals promote to fp32; cast result back to dt) + mx::array sqrtMaskSum = mx::sqrt(maskSumDt); + mx::array scaleFactor = (sqrtMaskSum - mx::array(14.0f)) * mx::array(0.1f); + mx::array meanScaled = mx::astype(mean * scaleFactor, dt); // dt + + // Max - masked positions pushed down in fp32 (1e9 overflows fp16 -> inf/NaN), + // then cast back to dt. Skip the mask adjustment when useMask=false. + mx::array maxVal = useMask + ? mx::max(input - (mx::array(1.0f) - mask) * mx::array(1e9f), spatialAxes, /*keepdims=*/true) + : mx::max(input, spatialAxes, /*keepdims=*/true); + maxVal = mx::astype(maxVal, dt); // dt + + // Concatenate along channel axis (axis 3 for NHWC); all components are dt + std::vector concatInputs = {mean, meanScaled, maxVal}; + return mx::concatenate(concatInputs, /*axis=*/3); +} + +// Value head pooling: computes [mean, mean * (sqrt(maskSum) - 14) * 0.1, mean * ((sqrt-14)^2 * 0.01 - 0.1)] +static mx::array applyValueHeadPooling(const mx::array& input, const mx::array& maskSum) { + // input: NHWC [N, H, W, C] + // maskSum: N111 [N, 1, 1, 1] + + // fp16-consistent (see applyGlobalPooling): keep the value-head pooling in the + // compute dtype so the v2 matmul stays fp16. + const auto dt = input.dtype(); + std::vector spatialAxes = {1, 2}; + mx::array spatialSum = mx::sum(input, spatialAxes, /*keepdims=*/true); // dt + mx::array maskSumDt = mx::astype(maskSum, dt); + mx::array mean = spatialSum / maskSumDt; // dt + + mx::array sqrtMaskSum = mx::sqrt(maskSumDt); + mx::array diff = sqrtMaskSum - mx::array(14.0f); + mx::array meanScaled1 = mx::astype(mean * diff * mx::array(0.1f), dt); + mx::array meanScaled2 = mx::astype(mean * (diff * diff * mx::array(0.01f) - mx::array(0.1f)), dt); + + std::vector concatInputs = {mean, meanScaled1, meanScaled2}; + return mx::concatenate(concatInputs, /*axis=*/3); +} + +// Residual Block +// Map KataGo activation enum -> kWinoOutputSourceBNAct ACT id; -1 = unsupported +// (caller must fall back to the unfused path). +static int mapEpilogueAct(int activation) { + switch(activation) { + case ACTIVATION_IDENTITY: return 0; + case ACTIVATION_MISH: return 1; + case ACTIVATION_RELU: return 2; + default: return -1; // SILU etc.: not fused + } +} + +// conv -> BN+activation, fused into the Winograd untransform when the conv is +// Winograd, the batch is unmasked, and the activation is supported. Otherwise +// the exact pre-existing decomposition. +static mx::array fusedConvBNAct(const ConvLayer& conv, const BatchNormLayer& bn, + const mx::array& x, const mx::array& mask, bool useMask) { + int act = mapEpilogueAct(bn.activation); + if(conv.useWinograd && !useMask && act >= 0) { + return MLXWinograd::winogradConv2d( + x, conv.winogradWeights, conv.outChannels, conv.winoInCfg, conv.winoOutCfg, conv.useFP16, + MLXWinograd::Epilogue::bnAct(bn.mergedScale, bn.mergedBias, act)); + } + return bn.apply(conv.apply(x), mask, useMask); +} + +// conv -> (residual + conv), fused into the Winograd untransform when the conv +// is Winograd and the batch is unmasked. Otherwise the exact decomposition. +static mx::array fusedConvResidual(const ConvLayer& conv, const mx::array& x, + const mx::array& resid, bool useMask) { + if(conv.useWinograd && !useMask) { + return MLXWinograd::winogradConv2d( + x, conv.winogradWeights, conv.outChannels, conv.winoInCfg, conv.winoOutCfg, conv.useFP16, + MLXWinograd::Epilogue::residual(resid)); + } + return resid + conv.apply(x); +} + +// conv -> (+ broadcast bias) -> BN+activation, fused into the Winograd untransform +// when the conv is Winograd, unmasked, and the activation is supported. gbias is +// the gpool bias [N,outC] (T). Otherwise the exact pre-existing decomposition. +static mx::array fusedConvBiasBNAct(const ConvLayer& conv, const BatchNormLayer& bn, + const mx::array& x, const mx::array& gbias, + const mx::array& mask, bool useMask) { + int act = mapEpilogueAct(bn.activation); + if(conv.useWinograd && !useMask && act >= 0) { + return MLXWinograd::winogradConv2d( + x, conv.winogradWeights, conv.outChannels, conv.winoInCfg, conv.winoOutCfg, conv.useFP16, + MLXWinograd::Epilogue::biasBNAct(gbias, bn.mergedScale, bn.mergedBias, act)); + } + mx::Shape bShape = {static_cast(gbias.shape()[0]), 1, 1, static_cast(gbias.shape()[1])}; + return bn.apply(conv.apply(x) + mx::reshape(gbias, bShape), mask, useMask); +} + +struct ResidualBlock { + const string name; + const BatchNormLayer preBN; + const ConvLayer regularConv; + const BatchNormLayer midBN; + const ConvLayer finalConv; + + ResidualBlock() = delete; + ResidualBlock(const ResidualBlock&) = delete; + ResidualBlock& operator=(const ResidualBlock&) = delete; + + ResidualBlock(const ResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : name(desc.name), + preBN(desc.preBN, desc.preActivation.activation, useFP16), + regularConv(desc.regularConv, inCfg, outCfg, useFP16), + midBN(desc.midBN, desc.midActivation.activation, useFP16), + finalConv(desc.finalConv, inCfg, outCfg, useFP16) + {} + + mx::array apply(const mx::array& input, const mx::array& mask, bool useMask) const { + mx::array out = preBN.apply(input, mask, useMask); + out = fusedConvBNAct(regularConv, midBN, out, mask, useMask); // fuse regularConv -> midBN + return fusedConvResidual(finalConv, out, input, useMask); // fuse finalConv -> input + out + } +}; + +// Global Pooling Residual Block +struct GlobalPoolingResidualBlock { + const string name; + const BatchNormLayer preBN; + const ConvLayer regularConv; + const ConvLayer gpoolConv; + const BatchNormLayer gpoolBN; + const MatMulLayer gpoolToBiasMul; + const BatchNormLayer midBN; + const ConvLayer finalConv; + + GlobalPoolingResidualBlock() = delete; + GlobalPoolingResidualBlock(const GlobalPoolingResidualBlock&) = delete; + GlobalPoolingResidualBlock& operator=(const GlobalPoolingResidualBlock&) = delete; + + GlobalPoolingResidualBlock(const GlobalPoolingResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : name(desc.name), + preBN(desc.preBN, desc.preActivation.activation, useFP16), + regularConv(desc.regularConv, inCfg, outCfg, useFP16), + gpoolConv(desc.gpoolConv, inCfg, outCfg, useFP16), + gpoolBN(desc.gpoolBN, desc.gpoolActivation.activation, useFP16), + gpoolToBiasMul(desc.gpoolToBiasMul, useFP16), + midBN(desc.midBN, desc.midActivation.activation, useFP16), + finalConv(desc.finalConv, inCfg, outCfg, useFP16) + {} + + mx::array apply(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) const { + mx::array preOut = preBN.apply(input, mask, useMask); + + // Global pooling path -> broadcast bias [N, outC] + mx::array gpoolOut = gpoolConv.apply(preOut); + gpoolOut = gpoolBN.apply(gpoolOut, mask, useMask); + mx::array pooled = applyGlobalPooling(gpoolOut, mask, maskSum, useMask); + std::vector squeezeAxes = {1, 2}; + mx::array bias = gpoolToBiasMul.apply(mx::squeeze(pooled, squeezeAxes)); // [N, outC], T + + // Fuse regularConv -> (+bias) -> midBN(BN+act), then finalConv -> input + out. + mx::array combined = fusedConvBiasBNAct(regularConv, midBN, preOut, bias, mask, useMask); + return fusedConvResidual(finalConv, combined, input, useMask); + } +}; + +// Nested Bottleneck Residual Block (simplified - forward declaration for recursive types) +struct NestedBottleneckResidualBlock; + +// Block variant type for trunk +struct BlockVariant { + enum Type { REGULAR, GLOBAL_POOLING, NESTED_BOTTLENECK, TRANSFORMER_ATTENTION, TRANSFORMER_FFN }; + Type type; + unique_ptr regular; + unique_ptr globalPooling; + unique_ptr nestedBottleneck; + unique_ptr attention; + unique_ptr ffn; + + BlockVariant(const ResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : type(REGULAR), regular(make_unique(desc, inCfg, outCfg, useFP16)) {} + + BlockVariant(const GlobalPoolingResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : type(GLOBAL_POOLING), globalPooling(make_unique(desc, inCfg, outCfg, useFP16)) {} + + // Transformer blocks have no convolutions, so they take board dims (nnX, nnY) + // instead of Winograd transform configs. + BlockVariant(const TransformerAttentionDesc& desc, int nnX, int nnY, bool useFP16 = false) + : type(TRANSFORMER_ATTENTION), attention(make_unique(desc, nnX, nnY, useFP16)) {} + + BlockVariant(const TransformerFFNDesc& desc, int nnX, int nnY, bool useFP16 = false) + : type(TRANSFORMER_FFN), ffn(make_unique(desc, nnX, nnY, useFP16)) {} + + // Forward declaration - defined after NestedBottleneckResidualBlock. + // Takes nnX/nnY so any transformer blocks nested inside the bottleneck can + // precompute their RoPE tables. + BlockVariant(const NestedBottleneckResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + int nnX, + int nnY, + bool useFP16); + + mx::array apply(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) const; +}; + +struct NestedBottleneckResidualBlock { + const string name; + const BatchNormLayer preBN; + const ConvLayer preConv; + vector blocks; + const BatchNormLayer postBN; + const ConvLayer postConv; + + NestedBottleneckResidualBlock() = delete; + NestedBottleneckResidualBlock(const NestedBottleneckResidualBlock&) = delete; + NestedBottleneckResidualBlock& operator=(const NestedBottleneckResidualBlock&) = delete; + + NestedBottleneckResidualBlock(const NestedBottleneckResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + int nnX, + int nnY, + bool useFP16 = false) + : name(desc.name), + preBN(desc.preBN, desc.preActivation.activation, useFP16), + preConv(desc.preConv, inCfg, outCfg, useFP16), + postBN(desc.postBN, desc.postActivation.activation, useFP16), + postConv(desc.postConv, inCfg, outCfg, useFP16) + { + // Mirror parseResidualBlockStack (desc.cpp), which accepts the same five + // block kinds inside a nested bottleneck as in the trunk - including a + // nested bottleneck within a nested bottleneck. Keep this in sync with the + // trunk's block loop below; an unhandled kind is a loud bug, not a silent + // no-op block. + for(size_t i = 0; i < desc.blocks.size(); i++) { + int blockKind = desc.blocks[i].first; + if(blockKind == ORDINARY_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, useFP16); + } + else if(blockKind == GLOBAL_POOLING_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, useFP16); + } + else if(blockKind == NESTED_BOTTLENECK_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, nnX, nnY, useFP16); + } + else if(blockKind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), nnX, nnY, useFP16); + } + else if(blockKind == TRANSFORMER_FFN_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), nnX, nnY, useFP16); + } + else { + ASSERT_UNREACHABLE; + } + } + } + + mx::array apply(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) const { + mx::array out = preBN.apply(input, mask, useMask); + out = preConv.apply(out); + for(const auto& block : blocks) { + out = block.apply(out, mask, maskSum, useMask); + } + out = postBN.apply(out, mask, useMask); + return fusedConvResidual(postConv, out, input, useMask); // fuse postConv -> input + out + } +}; + +// Define BlockVariant constructor for NestedBottleneckResidualBlock now that it's complete +BlockVariant::BlockVariant(const NestedBottleneckResidualBlockDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + int nnX, + int nnY, + bool useFP16) + : type(NESTED_BOTTLENECK), nestedBottleneck(make_unique(desc, inCfg, outCfg, nnX, nnY, useFP16)) {} + +mx::array BlockVariant::apply(const mx::array& input, const mx::array& mask, const mx::array& maskSum, bool useMask) const { + switch(type) { + case REGULAR: + return regular->apply(input, mask, useMask); + case GLOBAL_POOLING: + return globalPooling->apply(input, mask, maskSum, useMask); + case NESTED_BOTTLENECK: + return nestedBottleneck->apply(input, mask, maskSum, useMask); + case TRANSFORMER_ATTENTION: + return attention->apply(input, mask, useMask); + case TRANSFORMER_FFN: + return ffn->apply(input, mask, useMask); + default: + // All BlockVariant::Type values are handled above. Reaching the default + // means the tagged union holds an unrecognized type - fail loudly rather + // than silently returning the input (an identity no-op block). The + // default also satisfies -Wswitch-default (see cpp/CMakeLists.txt). + ASSERT_UNREACHABLE; + } +} + +// SGF Metadata Encoder +struct SGFMetadataEncoder { + const int metaEncoderVersion; + const int numInputMetaChannels; + const MatMulLayer mul1; + const MatBiasLayer bias1; + const int act1; + const MatMulLayer mul2; + const MatBiasLayer bias2; + const int act2; + const MatMulLayer mul3; + + SGFMetadataEncoder() = delete; + SGFMetadataEncoder(const SGFMetadataEncoder&) = delete; + SGFMetadataEncoder& operator=(const SGFMetadataEncoder&) = delete; + + SGFMetadataEncoder(const SGFMetadataEncoderDesc& desc, bool useFP16 = false) + : metaEncoderVersion(desc.metaEncoderVersion), + numInputMetaChannels(desc.numInputMetaChannels), + mul1(desc.mul1, useFP16), + bias1(desc.bias1, useFP16), + act1(desc.act1.activation), + mul2(desc.mul2, useFP16), + bias2(desc.bias2, useFP16), + act2(desc.act2.activation), + mul3(desc.mul3, useFP16) + {} + + mx::array apply(const mx::array& metaInput) const { + // Fuse matmul + bias with addmm for better performance + mx::array out = matmulBias(metaInput, mul1.weights, bias1.bias); + out = applyActivation(out, act1); + out = matmulBias(out, mul2.weights, bias2.bias); + out = applyActivation(out, act2); + out = mul3.apply(out); // Last layer has no bias + return out; + } +}; + +// Trunk +struct Trunk { + const string name; + const int trunkNumChannels; + const int trunkNormKind; + const ConvLayer initialConv; + const MatMulLayer initialMatMul; + unique_ptr sgfMetadataEncoder; + vector blocks; + // Exactly one of these is populated depending on trunkNormKind. For + // TRUNK_NORM_KIND_RMSNORM the trunkTipBN desc on disk is empty (size-0 + // weights), so constructing a BatchNormLayer from it would create empty + // arrays that broadcast against the spatial trunk and crash. Mirrors Eigen's + // Trunk (eigenbackend.cpp 1875-1880) which selects BN vs RMSNorm at load. + unique_ptr trunkTipBN; + unique_ptr trunkTipRMSNorm; + + Trunk() = delete; + Trunk(const Trunk&) = delete; + Trunk& operator=(const Trunk&) = delete; + + Trunk(const TrunkDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + int nnX, + int nnY, + bool useFP16 = false) + : name(desc.name), + trunkNumChannels(desc.trunkNumChannels), + trunkNormKind(desc.trunkNormKind), + initialConv(desc.initialConv, inCfg, outCfg, useFP16), + initialMatMul(desc.initialMatMul, useFP16) + { + // Trunk tip normalization: BatchNorm for standard nets, RMSNorm for + // transformer nets. Only the desc matching trunkNormKind was parsed. + if(desc.trunkNormKind == TRUNK_NORM_KIND_STANDARD) { + trunkTipBN = make_unique(desc.trunkTipBN, desc.trunkTipActivation.activation, useFP16); + } + else { + trunkTipRMSNorm = make_unique(desc.trunkTipRMSNorm, desc.trunkTipActivation.activation, useFP16); + } + + if(desc.sgfMetadataEncoder.metaEncoderVersion > 0 && desc.sgfMetadataEncoder.numInputMetaChannels > 0) { + sgfMetadataEncoder = make_unique(desc.sgfMetadataEncoder, useFP16); + } + + for(size_t i = 0; i < desc.blocks.size(); i++) { + int blockKind = desc.blocks[i].first; + if(blockKind == ORDINARY_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, useFP16); + } + else if(blockKind == GLOBAL_POOLING_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, useFP16); + } + else if(blockKind == NESTED_BOTTLENECK_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), inCfg, outCfg, nnX, nnY, useFP16); + } + else if(blockKind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), nnX, nnY, useFP16); + } + else if(blockKind == TRANSFORMER_FFN_BLOCK_KIND) { + blocks.emplace_back(*static_cast(desc.blocks[i].second.get()), nnX, nnY, useFP16); + } + else { + // parseResidualBlockStack (desc.cpp) rejects any other kind at load, + // so reaching here means a new block kind was added without backend + // support - fail loudly instead of silently dropping the block. + ASSERT_UNREACHABLE; + } + } + } + + mx::array apply( + const mx::array& input, + const mx::array& inputGlobal, + const mx::array* inputMeta, + const mx::array& mask, + const mx::array& maskSum, + bool useMask + ) const { + // Initial conv + mx::array trunk = initialConv.apply(input); + + // Add global input bias + mx::array globalBias = initialMatMul.apply(inputGlobal); + // Reshape from [N, C] to [N, 1, 1, C] for broadcasting + mx::Shape globalBiasShape = {static_cast(globalBias.shape()[0]), 1, 1, static_cast(globalBias.shape()[1])}; + globalBias = mx::reshape(globalBias, globalBiasShape); + trunk = trunk + globalBias; + + // Add SGF metadata if present + if(sgfMetadataEncoder && inputMeta != nullptr) { + mx::array metaBias = sgfMetadataEncoder->apply(*inputMeta); + mx::Shape metaBiasShape = {static_cast(metaBias.shape()[0]), 1, 1, static_cast(metaBias.shape()[1])}; + metaBias = mx::reshape(metaBias, metaBiasShape); + trunk = trunk + metaBias; + } + + // Apply mask - skip when useMask=false (all positions valid) + if(useMask) + trunk = trunk * mask; + + // Apply residual blocks + for(const auto& block : blocks) { + trunk = block.apply(trunk, mask, maskSum, useMask); + } + + // Final trunk-tip normalization + activation: BatchNorm or RMSNorm. + if(trunkNormKind == TRUNK_NORM_KIND_STANDARD) { + trunk = trunkTipBN->apply(trunk, mask, useMask); + } + else { + trunk = trunkTipRMSNorm->apply(trunk, mask, maskSum, useMask); + } + + return trunk; + } +}; + +// Policy Head +struct PolicyHead { + const string name; + const int modelVersion; + const ConvLayer p1Conv; + const ConvLayer g1Conv; + const BatchNormLayer g1BN; + const MatMulLayer gpoolToBiasMul; + const BatchNormLayer p1BN; + const ConvLayer p2Conv; + const MatMulLayer gpoolToPassMul; + // v15+ two-layer pass head: gpoolToPassMul (input -> hidden) -> + // gpoolToPassBias -> passActivation -> gpoolToPassMul2 (hidden -> output). + // Pre-v15 models use a single matmul (gpoolToPassMul: input -> output) and + // these three fields stay empty / zero. Mirrors the v15+ branch of + // PolicyHeadDesc::PolicyHeadDesc in desc.cpp and Metal's + // policyHeadDescToSwift in metalbackend.cpp. + const std::optional gpoolToPassBias; + const int passActivationType; + const std::optional gpoolToPassMul2; + + PolicyHead() = delete; + PolicyHead(const PolicyHead&) = delete; + PolicyHead& operator=(const PolicyHead&) = delete; + + PolicyHead(const PolicyHeadDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : name(desc.name), + modelVersion(desc.modelVersion), + p1Conv(desc.p1Conv, inCfg, outCfg, useFP16), + g1Conv(desc.g1Conv, inCfg, outCfg, useFP16), + g1BN(desc.g1BN, desc.g1Activation.activation, useFP16), + gpoolToBiasMul(desc.gpoolToBiasMul, useFP16), + p1BN(desc.p1BN, desc.p1Activation.activation, useFP16), + p2Conv(desc.p2Conv, inCfg, outCfg, useFP16), + gpoolToPassMul(desc.gpoolToPassMul, useFP16), + gpoolToPassBias(desc.modelVersion >= 15 + ? std::optional(std::in_place, desc.gpoolToPassBias, useFP16) + : std::nullopt), + passActivationType(desc.modelVersion >= 15 ? desc.passActivation.activation : 0), + gpoolToPassMul2(desc.modelVersion >= 15 + ? std::optional(std::in_place, desc.gpoolToPassMul2, useFP16) + : std::nullopt) + {} + + std::pair apply( + const mx::array& trunk, + const mx::array& mask, + const mx::array& maskSum, + bool useMask + ) const { + // Policy conv + mx::array p1Out = p1Conv.apply(trunk); + + // Global pooling path + mx::array g1Out = g1Conv.apply(trunk); + g1Out = g1BN.apply(g1Out, mask, useMask); + mx::array pooled = applyGlobalPooling(g1Out, mask, maskSum, useMask); + std::vector squeezeAxes = {1, 2}; + mx::array pooledFlat = mx::squeeze(pooled, squeezeAxes); + + // Add bias from global pooling + mx::array bias = gpoolToBiasMul.apply(pooledFlat); + mx::Shape biasShape = {static_cast(bias.shape()[0]), 1, 1, static_cast(bias.shape()[1])}; + bias = mx::reshape(bias, biasShape); + p1Out = p1Out + bias; + + p1Out = p1BN.apply(p1Out, mask, useMask); + + // Final policy conv + mx::array policy = p2Conv.apply(p1Out); + + // Pass policy: pre-v15 is a single matmul (pooled -> output). v15+ is a + // two-layer MLP (pooled -> hidden, + bias, activation, hidden -> output). + // Mirrors the v15+ branch of PolicyHeadDesc::PolicyHeadDesc in desc.cpp + // and Metal's policyHeadDescToSwift in metalbackend.cpp. + mx::array policyPass = gpoolToPassMul.apply(pooledFlat); + if(modelVersion >= 15) { + policyPass = gpoolToPassBias->apply(policyPass); + policyPass = applyActivation(policyPass, passActivationType); + policyPass = gpoolToPassMul2->apply(policyPass); + } + + return {policyPass, policy}; + } +}; + +// Value Head +struct ValueHead { + const string name; + const int modelVersion; + const ConvLayer v1Conv; + const BatchNormLayer v1BN; + const MatMulLayer v2Mul; + const MatBiasLayer v2Bias; + const int v2Activation; + const MatMulLayer v3Mul; + const MatBiasLayer v3Bias; + const MatMulLayer sv3Mul; + const MatBiasLayer sv3Bias; + const ConvLayer vOwnershipConv; + + ValueHead() = delete; + ValueHead(const ValueHead&) = delete; + ValueHead& operator=(const ValueHead&) = delete; + + ValueHead(const ValueHeadDesc& desc, + const MLXWinograd::InputTransform& inCfg, + const MLXWinograd::OutputUntransform& outCfg, + bool useFP16 = false) + : name(desc.name), + modelVersion(desc.modelVersion), + v1Conv(desc.v1Conv, inCfg, outCfg, useFP16), + v1BN(desc.v1BN, desc.v1Activation.activation, useFP16), + v2Mul(desc.v2Mul, useFP16), + v2Bias(desc.v2Bias, useFP16), + v2Activation(desc.v2Activation.activation), + v3Mul(desc.v3Mul, useFP16), + v3Bias(desc.v3Bias, useFP16), + sv3Mul(desc.sv3Mul, useFP16), + sv3Bias(desc.sv3Bias, useFP16), + vOwnershipConv(desc.vOwnershipConv, inCfg, outCfg, useFP16) + {} + + std::tuple apply( + const mx::array& trunk, + const mx::array& mask, + const mx::array& maskSum, + bool useMask + ) const { + mx::array v1Out = v1Conv.apply(trunk); + v1Out = v1BN.apply(v1Out, mask, useMask); + + // Value head pooling (only uses maskSum, not mask) + mx::array v1Mean = applyValueHeadPooling(v1Out, maskSum); + std::vector squeezeAxes = {1, 2}; + mx::array v1MeanFlat = mx::squeeze(v1Mean, squeezeAxes); + + // Fuse matmul + bias with addmm for better performance + mx::array v2Out = matmulBias(v1MeanFlat, v2Mul.weights, v2Bias.bias); + v2Out = applyActivation(v2Out, v2Activation); + + mx::array value = matmulBias(v2Out, v3Mul.weights, v3Bias.bias); + mx::array scoreValue = matmulBias(v2Out, sv3Mul.weights, sv3Bias.bias); + + mx::array ownership = vOwnershipConv.apply(v1Out); + + return {value, scoreValue, ownership}; + } +}; + +// Model +struct Model { + const string name; + const int modelVersion; + const int numInputChannels; + const int numInputGlobalChannels; + const int numInputMetaChannels; + const int numPolicyChannels; + // Pass-policy output width. For v15+ models the pass head is two-layer: + // gpoolToPassMul (input -> hidden) -> bias -> activation -> gpoolToPassMul2 + // (hidden -> output). The actual final output width — and the per-row stride + // extractOutputs in metalbackend.swift uses for its writes + // (batchIndex * numPolicyChannels) — is gpoolToPassMul2.outChannels, which + // PolicyHeadDesc::PolicyHeadDesc in desc.cpp validates equals + // numPolicyChannels. Pre-v15 models have a single matmul (gpoolToPassMul: + // input -> output) and the output width is gpoolToPassMul.outChannels = + // numPolicyChannels (also validated in PolicyHeadDesc::PolicyHeadDesc). + // Using gpoolToPassMul.outChannels for v15+ was the prior bug: it is the + // hidden width, not the output width, and rows >= 1 in batched ANE reads + // landed on uninitialized memory. + const int numPolicyPassChannels; + const int numValueChannels; + const int numScoreValueChannels; + const int numOwnershipChannels; + const bool useFP16; + + const Trunk trunk; + const PolicyHead policyHead; + const ValueHead valueHead; + + Model() = delete; + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + + // nnX/nnY (board dims) are needed by transformer attention blocks to + // precompute RoPE cos/sin tables, which are position-dependent. The MLX + // compiled graph is keyed by nnXLen/nnYLen so a Model built for one board + // size is never reused for another. + Model(const ModelDesc& desc, const MLXWinogradTuneParams& tuneParams, int nnX, int nnY, bool useFP16_ = false) + : name(desc.name), + modelVersion(desc.modelVersion), + numInputChannels(desc.numInputChannels), + numInputGlobalChannels(desc.numInputGlobalChannels), + numInputMetaChannels(desc.numInputMetaChannels), + numPolicyChannels(desc.numPolicyChannels), + numPolicyPassChannels(desc.modelVersion >= 15 + ? desc.policyHead.gpoolToPassMul2.outChannels + : desc.policyHead.gpoolToPassMul.outChannels), + numValueChannels(desc.numValueChannels), + numScoreValueChannels(desc.numScoreValueChannels), + numOwnershipChannels(desc.numOwnershipChannels), + useFP16(useFP16_), + trunk(desc.trunk, tuneParams.inputTransform, tuneParams.outputUntransform, nnX, nnY, useFP16_), + policyHead(desc.policyHead, tuneParams.inputTransform, tuneParams.outputUntransform, useFP16_), + valueHead(desc.valueHead, tuneParams.inputTransform, tuneParams.outputUntransform, useFP16_) + {} + + // Apply model inference with mx::array inputs directly (for compiled execution) + // inputs: [input, inputGlobal, mask, maskSum] or [input, inputGlobal, mask, maskSum, inputMeta] + // outputs: [policy, policyPass, value, scoreValue, ownership] + std::vector applyArrays( + const std::vector& inputs, + bool useMask + ) const { + // Convert inputs to compute dtype if FP16 is enabled + mx::array input = toComputeDtype(inputs[0], useFP16); + mx::array inputGlobalArr = toComputeDtype(inputs[1], useFP16); + mx::array mask = toComputeDtype(inputs[2], useFP16); + // maskSum stays FP32 - small scalar, negligible impact + const mx::array& maskSum = inputs[3]; + unique_ptr inputMeta; + if(inputs.size() > 4) { + inputMeta = make_unique(toComputeDtype(inputs[4], useFP16)); + } + const mx::array* inputMetaPtr = inputMeta.get(); + + // Apply trunk + mx::array trunkOut = trunk.apply(input, inputGlobalArr, inputMetaPtr, mask, maskSum, useMask); + + // Apply policy head + auto [policyPass, policy] = policyHead.apply(trunkOut, mask, maskSum, useMask); + + // Apply value head + auto [value, scoreValue, ownership] = valueHead.apply(trunkOut, mask, maskSum, useMask); + + // Convert outputs back to FP32 for interface compatibility + if(useFP16) { + policy = mx::astype(policy, mx::float32); + policyPass = mx::astype(policyPass, mx::float32); + value = mx::astype(value, mx::float32); + scoreValue = mx::astype(scoreValue, mx::float32); + ownership = mx::astype(ownership, mx::float32); + } + + return {policy, policyPass, value, scoreValue, ownership}; + } + + // Create a compiled inference function for the given configuration + // hasMeta is used as part of the cache key but not needed in the function itself + CompiledInferenceFunc createCompiledFunc(bool useMask, bool /*hasMeta*/) const { + // Create lambda that captures this model + auto inferenceFunc = [this, useMask](const std::vector& inputs) -> std::vector { + return this->applyArrays(inputs, useMask); + }; + + // Wrap in std::function and compile + std::function(const std::vector&)> func = inferenceFunc; + return mx::compile(func, /*shapeless=*/false); + } + + // Apply model using a pre-compiled inference function + void applyCompiled( + const CompiledInferenceFunc& compiledFunc, + const float* inputSpatial, + const float* inputGlobal, + const float* inputMeta, + int batchSize, + int nnXLen, + int nnYLen, + bool requireExactNNLen, + float* policyOut, + float* policyPassOut, + float* valueOut, + float* scoreValueOut, + float* ownershipOut + ) const { + std::vector outputs; + { + // Serialize graph construction + command encoding only. async_eval + // encodes all GPU work synchronously on this thread before returning; + // the completion wait and result readback below run outside the lock + // so another server thread can encode its batch while this one waits. + // See mlxGpuEvalMutex for the full rationale. + std::lock_guard gpuLock(mlxGpuEvalMutex); + + // Create input tensors - NHWC format + mx::Shape inputShape = {batchSize, nnYLen, nnXLen, numInputChannels}; + mx::array input = mx::array(inputSpatial, inputShape, mx::float32); + mx::Shape globalShape = {batchSize, numInputGlobalChannels}; + mx::array inputGlobalArr = mx::array(inputGlobal, globalShape, mx::float32); + + // Extract mask from first channel of input + mx::Shape sliceStart = {0, 0, 0, 0}; + mx::Shape sliceEnd = {batchSize, nnYLen, nnXLen, 1}; + mx::array mask = mx::slice(input, sliceStart, sliceEnd); + + // Compute mask sum + std::vector sumAxes = {1, 2}; + mx::array maskSum = requireExactNNLen + ? mx::full({batchSize, 1, 1, 1}, static_cast(nnXLen * nnYLen)) + : mx::sum(mask, sumAxes, /*keepdims=*/true); + + // Build input vector for compiled function + std::vector inputs = {input, inputGlobalArr, mask, maskSum}; + + // Add metadata if present + if(numInputMetaChannels > 0 && inputMeta != nullptr) { + mx::Shape metaShape = {batchSize, numInputMetaChannels}; + inputs.push_back(mx::array(inputMeta, metaShape, mx::float32)); + } + + // Call compiled function and encode the GPU work + outputs = compiledFunc(inputs); + mx::async_eval(outputs); + } + + // Wait for GPU completion. array::wait() blocks on the array's shared + // completion event without touching the command stream, so it is safe + // (and intended) to run while another thread holds mlxGpuEvalMutex to + // encode the next batch. + for(mx::array& out : outputs) + out.wait(); + + // Extract results - outputs are [policy, policyPass, value, scoreValue, ownership] + mx::array& policy = outputs[0]; + mx::array& policyPass = outputs[1]; + mx::array& value = outputs[2]; + mx::array& scoreValue = outputs[3]; + mx::array& ownership = outputs[4]; + + // Copy results to output buffers. Cast the leading operand to size_t so the + // byte-count product is computed in size_t throughout: the operands are all + // int, and on a large board (USE_BIGGER_BOARDS_EXPENSIVE) with a big batch + // the int sub-product could otherwise overflow before the sizeof promotion. + memcpy(policyOut, policy.data(), (size_t)batchSize * numPolicyChannels * nnXLen * nnYLen * sizeof(float)); + memcpy(policyPassOut, policyPass.data(), (size_t)batchSize * numPolicyPassChannels * sizeof(float)); + memcpy(valueOut, value.data(), (size_t)batchSize * numValueChannels * sizeof(float)); + memcpy(scoreValueOut, scoreValue.data(), (size_t)batchSize * numScoreValueChannels * sizeof(float)); + memcpy(ownershipOut, ownership.data(), (size_t)batchSize * numOwnershipChannels * nnXLen * nnYLen * sizeof(float)); + } +}; + +// Forward declaration needed by the helpers below (struct is defined in the +// "ComputeContext and ComputeHandle" section that follows). +struct ComputeContext; + +//------------------------------------------------------------------------------ +// CoreML/ANE compute handle helpers - mirrors convertAndCreateCoreMLOnlyHandle +// in metalbackend.cpp +//------------------------------------------------------------------------------ + +// Note: KataGoSwift::MetalComputeContext is the Swift-side context type. Its +// name is misleading in this file (MLX, not Metal) but we reuse it as-is per +// the design decision to leave KataGoSwift unchanged. It carries only +// (nnXLen, nnYLen, useFP16). + +// Helper: convert model and create CoreML-only compute handle (for mux ANE thread) +static swift::Optional convertAndCreateCoreMLOnlyHandleMLX( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int serverThreadIdx +); + +// Helper: create CoreML-only handle when gpuIdx == MLX_MUX_ANE. +// Returns Optional::none() for the GPU path. Emits the same FP16-only-ANE +// warning Metal emits when useFP16=false is combined with the ANE mux. +static swift::Optional createCoreMLOnlyHandleIfNeededMLX( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int gpuIdx, + int serverThreadIdx +); + +// ComputeContext and ComputeHandle ------------------------------------------------------------------------------------ + +// Session-scoped, cross-context Winograd-tune memo. The main net and the human +// SL net are separate NNEvaluators with separate ComputeContexts (gtp.cpp builds +// both), yet on iOS MLX/GPU both are b18c384 with identical 3x3-conv shapes, so +// their optimal transform geometry is the same — they should tune ONCE per +// engine session, not twice. A per-context memo can't span the two contexts; +// this process-global one does. It is scoped to a session by being cleared when +// the last ComputeContext is freed (a session == the window in which any context +// is alive), so a forced re-tune in a later session runs afresh instead of +// reusing a stale entry. Sequential context creation in gtp.cpp means the +// check/tune/store below needs no lock held across the (long) tune itself. +// +// All state (the map, its mutex, and the live-context refcount that drives the +// session-scoped clear) lives in this one unit so the locking obligation and the +// "clear on last release" invariant are defined in one place rather than spread +// across createComputeContext / freeComputeContext / the ComputeHandle ctor. +struct WinogradTuneMemo { + // Returns the memoized params for this shape key, or nullopt on a miss. + std::optional tryGet(const std::string& key) { + std::lock_guard lk(mutex); + auto it = memo.find(key); + if(it == memo.end()) + return std::nullopt; + return it->second; + } + void put(const std::string& key, const MLXWinogradTuneParams& params) { + std::lock_guard lk(mutex); + memo[key] = params; + } + // retain/release track how many ComputeContexts are alive. The memo is a + // session-scoped cache: when the last context goes away the session is over, + // so drop everything (a later forced re-tune must not reuse this session's + // results). + void retain() { + std::lock_guard lk(mutex); + liveContexts++; + } + void release() { + std::lock_guard lk(mutex); + if(--liveContexts <= 0) { + liveContexts = 0; + memo.clear(); + } + } +private: + std::mutex mutex; + std::map memo; + int liveContexts = 0; +}; +static WinogradTuneMemo g_winoTuneMemo; + +struct ComputeContext { + const int nnXLen; + const int nnYLen; + const enabled_t useFP16Mode; + std::string homeDataDirOverride; + Logger* logger; + + std::mutex cachedModelsMutex; + std::map> cachedModels; + std::map cachedModelsRefCount; + + // True only when EVERY configured device index is MLX_MUX_ANE (set in + // createComputeContext). The ANE path re-reads the model from disk in + // convertModelToTemp and otherwise reads only scalar dims, so when no thread + // will ever build the MLX/GPU model we can free the in-memory FP32 weights via + // ModelDesc::releaseWeights(). Must stay false if any thread uses MLX_MUX_GPU, + // which would read the freed weights. + bool aneOnly = false; + + // MLX/GPU Winograd autotuner controls, plumbed from the app via + // -override-config (mlxTunerFull / mlxReTune), read in createComputeContext. + // tunerFull=true selects the wide grid (thorough but much slower); reTune=true + // forces a fresh tune that ignores and overwrites the cached file. Consumed by + // the GPU ComputeHandle ctor's loadOrAutoTune call; ignored on the ANE path. + bool tunerFull = false; + bool tunerReTune = false; + + ComputeContext() = delete; + ComputeContext(const ComputeContext&) = delete; + ComputeContext& operator=(const ComputeContext&) = delete; + + ComputeContext(int nnX, int nnY, enabled_t fp16Mode, + const std::string& homeDataDirOverride_, Logger* logger_) + : nnXLen(nnX), + nnYLen(nnY), + useFP16Mode(fp16Mode), + homeDataDirOverride(homeDataDirOverride_), + logger(logger_), + cachedModelsMutex(), + cachedModels(), + cachedModelsRefCount() + {} + + ~ComputeContext() { + assert(cachedModels.size() == 0); + } +}; + +// Resolve the Winograd transform geometry for a GPU-path model load: honor the +// cross-context session memo (g_winoTuneMemo) first, then fall back to +// loadOrAutoTune (disk cache or a fresh sweep), storing the result in the memo. +// Returns default-constructed (identity) transforms when Winograd or the +// winotuner is disabled. Reads only ComputeContext + LoadedModel state and +// touches no ComputeHandle, so it lives as a free function the ctor calls. +static MLXWinogradTuneParams resolveTuneParams( + ComputeContext* context, const LoadedModel& loadedModel, bool useFP16) { + MLXWinogradTuneParams tuneParams; + if(!(mlxWinogradEnabled() && mlxWinotunerEnabled())) + return tuneParams; + + // Shape diagnostic: print the model's 3x3 conv shape distribution before + // calling the tuner so the log carries this signal on every load, including + // cache-hit runs where loadOrAutoTune short-circuits. + if(context->logger != NULL) { + context->logger->write( + MLXWinogradTuner::formatConv3x3Distribution(loadedModel.modelDesc)); + } + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = loadedModel.modelDesc.trunk.trunkNumChannels; + mi.modelVersion = loadedModel.modelDesc.modelVersion; + auto [inHist, outHist] = + MLXWinogradTuner::buildConv3x3Histograms(loadedModel.modelDesc); + mi.conv3x3InputHistogram = std::move(inHist); + mi.conv3x3OutputHistogram = std::move(outHist); + + // Cross-context shape memo (see g_winoTuneMemo): if a same-shape GPU + // handle already tuned this session, reuse its result and skip the sweep + // entirely. This is what keeps the main + human b18c384 nets at a single + // tune instead of two — halving model-load tuning time at zero quality + // cost (identical shape ⇒ identical optimal geometry). + // Include a signature of the 3x3-conv distribution: that histogram (not just + // trunkNumChannels) is what actually drives the tuned geometry via + // planShapeRotation, so two nets with the same trunk width but different conv + // shapes must NOT share a tune. modelVersion is deliberately omitted — same + // shape, different version (e.g. b18c384 main v14 + human v15) should share. + // The histograms are std::map-derived, hence deterministically sorted. + std::string histSig; + for(const auto& p : mi.conv3x3InputHistogram) + histSig += std::to_string(p.first) + ":" + std::to_string(p.second) + ","; + histSig += "/"; + for(const auto& p : mi.conv3x3OutputHistogram) + histSig += std::to_string(p.first) + ":" + std::to_string(p.second) + ","; + const std::string shapeKey = + std::to_string(mi.trunkNumChannels) + + "_" + std::to_string(context->nnXLen) + + "x" + std::to_string(context->nnYLen) + + (useFP16 ? "_fp16" : "_fp32") + + (context->tunerFull ? "_full" : "_fast") + + "_" + histSig; + if(auto memoized = g_winoTuneMemo.tryGet(shapeKey)) { + if(context->logger != NULL) + context->logger->write("Reusing MLX Winograd tuning for shape " + shapeKey + + " (already tuned this session)"); + return *memoized; + } + + tuneParams = MLXWinogradTuner::loadOrAutoTune( + /*tunerFile=*/"", + context->homeDataDirOverride, + MLXWinogradTuner::detectGpuName(), + context->nnXLen, context->nnYLen, + // Tuner times the Winograd input/output transform kernels at this + // batch size only (the matmul stage is untuned). Probed re-tuning + // at 8/16/32/64: the winning configs do differ per batch size, but + // end-to-end throughput stayed flat within ~1.5% run-to-run noise. + // OpenCL's tuner pins a single batch size too. Not worth + // parameterizing. + /*batchSize=*/8, + mi, + context->logger, + // full / reTune come from the app's MLX/GPU tuning UI via + // -override-config (mlxTunerFull / mlxReTune), read into the + // ComputeContext in createComputeContext. Defaults are false/false: + // load a valid cache, or on a miss tune the coarse "fast" grid once. + // full=true selects the wide grid (cached under a distinct "_full" + // file); reTune=true forces a fresh tune that overwrites the current + // mode's cache. `./katago tuner` (optionally -full) remains a separate + // always-overwrite entry point. + /*full=*/context->tunerFull, + /*reTune=*/context->tunerReTune, + /*useFP16=*/useFP16); + g_winoTuneMemo.put(shapeKey, tuneParams); + return tuneParams; +} + +struct ComputeHandle { + ComputeContext* context; + bool inputsUseNHWC; + bool requireExactNNLen; + bool useFP16; + int gpuIdx; + std::string modelCacheKey; // assigned in ctor body after loadOrAutoTune + std::shared_ptr model; + const int modelVersion; + + // ModelDesc fields cached on both paths so getOutput does not have to + // dereference `model` (which is nullptr on the ANE path). Populated in + // the constructor body for both MLX_MUX_GPU and MLX_MUX_ANE. + int numInputChannels; + int numPolicyChannels; + int numPolicyPassChannels; + int numValueChannels; + int numScoreValueChannels; + int numOwnershipChannels; + + // Compiled function cache - keyed by (batchSize, nnXLen, nnYLen, useMask, hasMeta, useFP16). + // Populated only on the MLX/GPU path; the ANE path uses coremlOnlyHandle instead. + mutable std::mutex compiledFuncsMutex; + mutable std::map compiledFuncs; + + // CoreML-only handle (Swift). Populated iff gpuIdx == MLX_MUX_ANE; otherwise none(). + // Exactly one of {model populated (MLX/GPU path) OR coremlOnlyHandle has value (ANE path)}. + swift::Optional coremlOnlyHandle; + + ComputeHandle() = delete; + ComputeHandle(const ComputeHandle&) = delete; + ComputeHandle& operator=(const ComputeHandle&) = delete; + + static std::string makeCacheKey(const LoadedModel& loadedModel, + const MLXWinogradTuneParams& tuneParams, + bool useFP16) { + return loadedModel.modelDesc.name + "-" + loadedModel.modelDesc.sha256 + + (useFP16 ? "-fp16" : "-fp32") + + (mlxWinogradEnabled() ? "-wg" : "-nowg") + + "-it" + std::to_string(tuneParams.inputTransform.tg0) + + "x" + std::to_string(tuneParams.inputTransform.tg1) + + "x" + std::to_string(tuneParams.inputTransform.wpt) + + "x" + std::to_string(tuneParams.inputTransform.vw) + + "g" + std::to_string((int)tuneParams.inputTransform.gridOrder) + + "-ou" + std::to_string(tuneParams.outputUntransform.tg0) + + "x" + std::to_string(tuneParams.outputUntransform.tg1) + + "x" + std::to_string(tuneParams.outputUntransform.wpt); + } + + ComputeHandle(ComputeContext* ctx, + const LoadedModel& loadedModel, + bool iNHWC, + bool requireExactNNLen_, + bool useFP16_, + int gpuIdx_, + int maxBatchSize, + int serverThreadIdx) + : context(ctx), + inputsUseNHWC(iNHWC), + requireExactNNLen(requireExactNNLen_), + useFP16(useFP16_), + gpuIdx(gpuIdx_), + modelCacheKey(), + model(nullptr), + modelVersion(loadedModel.modelDesc.modelVersion), + compiledFuncsMutex(), + compiledFuncs(), + coremlOnlyHandle(createCoreMLOnlyHandleIfNeededMLX( + ctx, &loadedModel, requireExactNNLen_, maxBatchSize, gpuIdx_, serverThreadIdx)) + { + // Cache ModelDesc fields used by both paths in getOutput. + numInputChannels = loadedModel.modelDesc.numInputChannels; + numPolicyChannels = loadedModel.modelDesc.numPolicyChannels; + // See Model::numPolicyPassChannels comment for the v15+ two-layer pass head + // rationale: the per-row stride must match the *final* pass output width + // (gpoolToPassMul2.outChannels for v15+, gpoolToPassMul.outChannels otherwise), + // not the hidden width. + numPolicyPassChannels = + loadedModel.modelDesc.modelVersion >= 15 + ? loadedModel.modelDesc.policyHead.gpoolToPassMul2.outChannels + : loadedModel.modelDesc.policyHead.gpoolToPassMul.outChannels; + numValueChannels = loadedModel.modelDesc.numValueChannels; + numScoreValueChannels = loadedModel.modelDesc.numScoreValueChannels; + numOwnershipChannels = loadedModel.modelDesc.numOwnershipChannels; + + if(gpuIdx_ == MLX_MUX_ANE) { + // ANE path: MLX inference state is intentionally left uninitialized. + checkExactlyOnePath(); + return; + } + + // GPU path: resolve the Winograd geometry (session memo → disk cache → + // fresh sweep; see resolveTuneParams), then build/cache the Model below. + MLXWinogradTuneParams tuneParams = resolveTuneParams(context, loadedModel, useFP16_); + + modelCacheKey = makeCacheKey(loadedModel, tuneParams, useFP16_); + + std::lock_guard lock(context->cachedModelsMutex); + if(context->cachedModels.find(modelCacheKey) == context->cachedModels.end()) { + context->cachedModels[modelCacheKey] = + std::make_shared(loadedModel.modelDesc, tuneParams, + context->nnXLen, context->nnYLen, useFP16_); + } + model = context->cachedModels[modelCacheKey]; + context->cachedModelsRefCount[modelCacheKey] += 1; + + // GPU path invariant check. + checkExactlyOnePath(); + } + + // Invariant: exactly one inference path is live — the MLX/GPU `model` OR the + // CoreML/ANE `coremlOnlyHandle`, never both or neither. Reads members set in + // the init list (gpuIdx, coremlOnlyHandle) and `model` as assigned so far, so + // it is valid at the end of either ctor path. + void checkExactlyOnePath() const { + bool hasMLX = (model != nullptr); + bool hasCoreML = static_cast(coremlOnlyHandle); + if(hasMLX == hasCoreML) { + throw runtime_error( + string("MLX backend: Logic error - expected exactly one compute handle, got ") + + (hasMLX && hasCoreML ? "both" : "neither") + + " (gpuIdx=" + to_string(gpuIdx) + ")"); + } + } + + ~ComputeHandle() { + // Only the GPU path populated the cachedModels map; ANE path's destructor + // is a no-op for the MLX-side state. Swift ARC releases coremlOnlyHandle + // automatically when the swift::Optional member is destroyed. + if(gpuIdx == MLX_MUX_ANE) + return; + + std::lock_guard lock(context->cachedModelsMutex); + context->cachedModelsRefCount[modelCacheKey] -= 1; + assert(context->cachedModelsRefCount[modelCacheKey] >= 0); + if(context->cachedModelsRefCount[modelCacheKey] == 0) { + context->cachedModelsRefCount.erase(modelCacheKey); + context->cachedModels.erase(modelCacheKey); + } + } + + // Get or create compiled inference function for the given configuration. + // GPU path only — must not be called on an ANE-mux handle. + const CompiledInferenceFunc& getCompiledFunc(int batchSize, int nnXLen, int nnYLen, bool useMask, bool hasMeta) const { + assert(gpuIdx == MLX_MUX_GPU); + CompileCacheKey key = std::make_tuple(batchSize, nnXLen, nnYLen, useMask, hasMeta, useFP16); + + std::lock_guard lock(compiledFuncsMutex); + auto it = compiledFuncs.find(key); + if(it != compiledFuncs.end()) { + return it->second; + } + + // Create and cache compiled function + compiledFuncs[key] = model->createCompiledFunc(useMask, hasMeta); + return compiledFuncs[key]; + } +}; + +// InputBuffers -------------------------------------------------------------------------------------------------------- + +struct InputBuffers { + int maxBatchSize; + + size_t singleInputElts; + size_t singleInputGlobalElts; + size_t singleInputMetaElts; + + size_t singlePolicyPassResultElts; + size_t singlePolicyResultElts; + size_t singleValueResultElts; + size_t singleScoreValueResultElts; + size_t singleOwnershipResultElts; + size_t singleMaskElts; + + std::vector spatialInput; + std::vector globalInput; + std::vector metaInput; + std::vector userInputMaskBuffer; + // NCHW staging buffer for the ANE/CoreML dispatch path. The Swift + // CoreMLComputeHandle.apply() allocates MLMultiArray with shape + // [1, C, H, W] and memcpys each row's bytes, so it strictly requires + // NCHW. spatialInput stays NHWC for the MLX/GPU path; rows are + // transposed into this buffer inside getOutput before dispatch. The + // MLX/GPU path never reads this buffer. + std::vector userInputBufferNCHW; + std::vector policyResults; + std::vector policyPassResults; + std::vector valueResults; + std::vector scoreValueResults; + std::vector ownershipResults; + + InputBuffers(const LoadedModel* loadedModel, int maxBatchSz, int nnXLen, int nnYLen) { + const ModelDesc& m = loadedModel->modelDesc; + + maxBatchSize = maxBatchSz; + singleInputElts = m.numInputChannels * nnXLen * nnYLen; + singleInputGlobalElts = m.numInputGlobalChannels; + singleInputMetaElts = m.numInputMetaChannels; + + // See Model::numPolicyPassChannels comment: pass output width is + // gpoolToPassMul2.outChannels for v15+, gpoolToPassMul.outChannels otherwise. + // Must match ComputeHandle::numPolicyPassChannels (assertion in getOutput). + singlePolicyPassResultElts = (size_t)( + m.modelVersion >= 15 + ? m.policyHead.gpoolToPassMul2.outChannels + : m.policyHead.gpoolToPassMul.outChannels); + singlePolicyResultElts = (size_t)(m.numPolicyChannels * nnXLen * nnYLen); + singleValueResultElts = (size_t)m.numValueChannels; + singleScoreValueResultElts = (size_t)m.numScoreValueChannels; + singleOwnershipResultElts = (size_t)m.numOwnershipChannels * nnXLen * nnYLen; + singleMaskElts = (size_t)nnXLen * nnYLen; + + assert(NNModelVersion::getNumSpatialFeatures(m.modelVersion) == m.numInputChannels); + assert(NNModelVersion::getNumGlobalFeatures(m.modelVersion) == m.numInputGlobalChannels); + if(m.numInputMetaChannels > 0) { + assert(SGFMetadata::METADATA_INPUT_NUM_CHANNELS == m.numInputMetaChannels); + } + + spatialInput.resize(m.numInputChannels * nnXLen * nnYLen * maxBatchSize); + globalInput.resize(m.numInputGlobalChannels * maxBatchSize); + if(m.numInputMetaChannels > 0) + metaInput.resize(m.numInputMetaChannels * maxBatchSize); + else + metaInput.resize(1); + + policyResults.resize(singlePolicyResultElts * maxBatchSize); + policyPassResults.resize(singlePolicyPassResultElts * maxBatchSize); + valueResults.resize(singleValueResultElts * maxBatchSize); + scoreValueResults.resize(singleScoreValueResultElts * maxBatchSize); + ownershipResults.resize(singleOwnershipResultElts * maxBatchSize); + userInputMaskBuffer.resize(singleMaskElts * maxBatchSize); + userInputBufferNCHW.resize(singleInputElts * maxBatchSize); + } + + ~InputBuffers() {} + + InputBuffers() = delete; + InputBuffers(const InputBuffers&) = delete; + InputBuffers& operator=(const InputBuffers&) = delete; +}; + +InputBuffers* NeuralNet::createInputBuffers(const LoadedModel* loadedModel, int maxBatchSize, int nnXLen, int nnYLen) { + return new InputBuffers(loadedModel, maxBatchSize, nnXLen, nnYLen); +} + +void NeuralNet::freeInputBuffers(InputBuffers* inputBuffers) { + delete inputBuffers; +} + +// NeuralNet Interface ------------------------------------------------------------------------------------------------- + +void NeuralNet::globalInitialize() { + // MLX initializes automatically +} + +void NeuralNet::globalCleanup() { + // MLX cleans up automatically +} + +// Helper implementations (forward-declared before ComputeContext; defined here +// after ComputeContext and LoadedModel are both fully visible). + +static swift::Optional convertAndCreateCoreMLOnlyHandleMLX( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int serverThreadIdx +) { + int nnXLen = context->nnXLen; + int nnYLen = context->nnYLen; + bool useFP16 = (context->useFP16Mode != enabled_t::False); + bool optimizeMask = requireExactNNLen; + + // Convert model to CoreML format in temp directory + string coremlModelPath = CoreMLConversion::convertModelToTemp( + loadedModel->modelPath, + nnXLen, + nnYLen, + useFP16, + optimizeMask, + maxBatchSize, + serverThreadIdx + ); + + // ANE-only context: the converter has just re-read the model from disk + // (modelPath) into CoreML form, and nothing afterward reads the in-memory + // weight arrays — the ComputeHandle ctor takes the MLX_MUX_ANE early-return + // before building any MLX/GPU model, and only scalar dims are read later + // (numInputChannels, ..., which releaseWeights() preserves). Runs under + // computeHandleMutex (held by createComputeHandle), so it is not racy; + // releaseWeights() is idempotent across the per-thread ANE handles. + if(context->aneOnly) { + const_cast(loadedModel)->modelDesc.releaseWeights(); + } + + // The Swift createCoreMLComputeHandle entry point expects a + // MetalComputeContext. Construct one on-the-fly from MLX's context values. + auto swiftContext = KataGoSwift::createMetalComputeContext( + static_cast(nnXLen), + static_cast(nnYLen), + useFP16); + + // Create CoreML-only compute handle (CPU+ANE) — same Swift entry point Metal uses. + return KataGoSwift::createCoreMLComputeHandle( + swift::String(coremlModelPath), + serverThreadIdx, + requireExactNNLen, + loadedModel->modelDesc.numInputChannels, + loadedModel->modelDesc.numInputGlobalChannels, + loadedModel->modelDesc.numInputMetaChannels, + loadedModel->modelDesc.numPolicyChannels, + loadedModel->modelDesc.numValueChannels, + loadedModel->modelDesc.numScoreValueChannels, + loadedModel->modelDesc.numOwnershipChannels, + swiftContext + ); +} + +static swift::Optional createCoreMLOnlyHandleIfNeededMLX( + ComputeContext* context, + const LoadedModel* loadedModel, + bool requireExactNNLen, + int maxBatchSize, + int gpuIdx, + int serverThreadIdx +) { + if(gpuIdx != MLX_MUX_ANE) { + return swift::Optional::none(); + } + + if(context->useFP16Mode == enabled_t::False) { + // Honor the user's explicit FP32 request even on an ANE thread: the ANE + // is FP16-only, so CoreML falls back to CPU. Result is correct (and + // deterministic) FP32 CoreML inference, just much slower than GPU. + cerr << "MLX backend " << serverThreadIdx << ": Note: ANE thread with mlxUseFP16=false: " + << "the ANE is FP16-only, so CoreML will run this thread on CPU (FP32). " + << "This is significantly slower than the GPU path; if you wanted ANE acceleration, " + << "remove mlxUseFP16=false." << endl; + } + + cerr << "MLX backend " << serverThreadIdx << ": Mux ANE mode - using CoreML (CPU+ANE)" << endl; + return convertAndCreateCoreMLOnlyHandleMLX(context, loadedModel, requireExactNNLen, maxBatchSize, serverThreadIdx); +} + +ComputeContext* NeuralNet::createComputeContext( + const std::vector& gpuIdxs, + Logger* logger, + int nnXLen, + int nnYLen, + const string& homeDataDirOverride, + enabled_t useFP16Mode, + const LoadedModel* loadedModel, + ConfigParser& cfg +) { + (void)loadedModel; + + // aneOnly drives the ANE-path weight release in convertAndCreateCoreMLOnlyHandleMLX. + // INVARIANT: gpuIdxs must be the complete, deduplicated set of device indices any + // thread will use under this context. Free the in-memory weights only when every one + // is MLX_MUX_ANE; if a thread later used an MLX_MUX_GPU index not represented here it + // would read freed weights. + bool aneOnly = !gpuIdxs.empty(); + for(int idx : gpuIdxs) { + if(idx != MLX_MUX_ANE) { aneOnly = false; break; } + } + + // MLX requires NHWC inputs; this is enforced per-handle via inputsUseNHWC in + // createComputeHandle (the old context-level useNHWCMode param was removed + // upstream when createComputeContext was consolidated onto ConfigParser). + ComputeContext* context = new ComputeContext(nnXLen, nnYLen, useFP16Mode, homeDataDirOverride, logger); + context->aneOnly = aneOnly; + // Track live contexts so the cross-context tune memo can be cleared when this + // engine session ends (see g_winoTuneMemo). + g_winoTuneMemo.retain(); + // MLX/GPU Winograd autotuner controls (app sets these via -override-config). + // Read here so the GPU ComputeHandle ctor can honor them; harmless on the ANE + // path, which returns before the tuner. Calling getBool marks the keys "used" + // so ConfigParser doesn't flag them as unused overrides. + context->tunerFull = cfg.contains("mlxTunerFull") ? cfg.getBool("mlxTunerFull") : false; + context->tunerReTune = cfg.contains("mlxReTune") ? cfg.getBool("mlxReTune") : false; + return context; +} + +void NeuralNet::freeComputeContext(ComputeContext* computeContext) { + delete computeContext; + // When the last context of this engine session goes away, release() drops the + // cross-context tune memo so the next session (e.g. a forced re-tune) starts + // fresh rather than reusing this session's results. + g_winoTuneMemo.release(); +} + +ComputeHandle* NeuralNet::createComputeHandle( + ComputeContext* context, + const LoadedModel* loadedModel, + Logger* logger, + int maxBatchSize, + bool requireExactNNLen, + bool inputsUseNHWC, + int gpuIdxForThisThread, + int serverThreadIdx +) { + // Auto resolves to fp16. The original acceptance gate (MLX-fp16 paired-t + // beat both Metal-fp16 and MLX-fp32 with non-overlapping CIs, and + // testgpuerror accuracy exit=0) is preserved in the traceability commit. + // Users who need bit-for-bit fp32 reproducibility set `mlxUseFP16 = false` + // explicitly. + bool useFP16 = (context->useFP16Mode != enabled_t::False); + + // gpuIdx == -1 is the "no preference" sentinel from upstream; map to default GPU. + int gpuIdx = (gpuIdxForThisThread == -1) ? MLX_MUX_GPU : gpuIdxForThisThread; + if(gpuIdx != MLX_MUX_GPU && gpuIdx != MLX_MUX_ANE) { + throw StringError( + "MLX backend: Invalid mlxDeviceToUseThread value " + std::to_string(gpuIdx) + + " for server thread " + std::to_string(serverThreadIdx) + + ". The MLX backend only supports " + std::to_string(MLX_MUX_GPU) + + " (GPU via MLX) or " + std::to_string(MLX_MUX_ANE) + + " (ANE via CoreML)."); + } + + if(logger != NULL) { + logger->write("MLX backend thread " + Global::intToString(serverThreadIdx) + ": Model version " + Global::intToString(loadedModel->modelDesc.modelVersion)); + logger->write("MLX backend thread " + Global::intToString(serverThreadIdx) + ": Model name: " + loadedModel->modelDesc.name); + logger->write("MLX backend thread " + Global::intToString(serverThreadIdx) + ": FP16 = " + (useFP16 ? "true" : "false")); + logger->write("MLX backend thread " + Global::intToString(serverThreadIdx) + ": gpuIdx = " + Global::intToString(gpuIdx)); + } + + if(!inputsUseNHWC) + throw StringError("MLX backend: inputsUseNHWC = false unsupported"); + + // Serialize handle construction: see computeHandleMutex declaration above. + std::lock_guard lock(computeHandleMutex); + return new ComputeHandle(context, *loadedModel, inputsUseNHWC, requireExactNNLen, useFP16, + gpuIdx, maxBatchSize, serverThreadIdx); +} + +void NeuralNet::freeComputeHandle(ComputeHandle* gpuHandle) { + delete gpuHandle; +} + +bool NeuralNet::isUsingFP16(const ComputeHandle* handle) { + return handle->useFP16; +} + +void NeuralNet::getOutput( + ComputeHandle* computeHandle, + InputBuffers* inputBuffers, + int numBatchEltsFilled, + NNResultBuf** inputBufs, + vector& outputs +) { + assert(numBatchEltsFilled <= inputBuffers->maxBatchSize); + assert(numBatchEltsFilled > 0); + const int batchSize = numBatchEltsFilled; + const int nnXLen = computeHandle->context->nnXLen; + const int nnYLen = computeHandle->context->nnYLen; + const int modelVersion = computeHandle->modelVersion; + + const int numSpatialFeatures = NNModelVersion::getNumSpatialFeatures(modelVersion); + const int numGlobalFeatures = NNModelVersion::getNumGlobalFeatures(modelVersion); + const int numMetaFeatures = inputBuffers->singleInputMetaElts; + assert(numSpatialFeatures == computeHandle->numInputChannels); + assert(numSpatialFeatures * nnXLen * nnYLen == inputBuffers->singleInputElts); + assert(numGlobalFeatures == inputBuffers->singleInputGlobalElts); + const int numPolicyChannels = computeHandle->numPolicyChannels; + + // Copy input data to buffers + for(int nIdx = 0; nIdx < batchSize; nIdx++) { + float* rowSpatialInput = inputBuffers->spatialInput.data() + (inputBuffers->singleInputElts * nIdx); + float* rowGlobalInput = inputBuffers->globalInput.data() + (inputBuffers->singleInputGlobalElts * nIdx); + float* rowMetaInput = inputBuffers->metaInput.data() + (inputBuffers->singleInputMetaElts * nIdx); + + const float* rowGlobal = inputBufs[nIdx]->rowGlobalBuf.data(); + const float* rowSpatial = inputBufs[nIdx]->rowSpatialBuf.data(); + const float* rowMeta = inputBufs[nIdx]->rowMetaBuf.data(); + const bool hasRowMeta = inputBufs[nIdx]->hasRowMeta; + + std::copy(rowGlobal, rowGlobal + numGlobalFeatures, rowGlobalInput); + + if(numMetaFeatures > 0) { + testAssert(rowMeta != NULL); + testAssert(hasRowMeta); + std::copy(rowMeta, rowMeta + numMetaFeatures, rowMetaInput); + } + else { + testAssert(!hasRowMeta); + } + + SymmetryHelpers::copyInputsWithSymmetry(rowSpatial, rowSpatialInput, 1, nnYLen, nnXLen, numSpatialFeatures, computeHandle->inputsUseNHWC, inputBufs[nIdx]->symmetry); + + // ANE/CoreML path needs an NCHW spatial buffer because the Swift + // CoreMLComputeHandle.apply() allocates MLMultiArray with shape + // [1, C, H, W] and raw memcpys C*H*W floats per row. spatialInput + // is NHWC (required by the MLX/GPU path's mx::array shape), so we + // transpose each row into userInputBufferNCHW here. The validity + // mask (channel 0) sits at the start of the converted row, so it + // collapses to a contiguous memcpy into userInputMaskBuffer. + // + // When the mlpackage was converted with optimize_identity_mask=true + // (i.e., requireExactNNLen=true) the ANE model ignores the mask + // buffer, but populating it unconditionally costs essentially + // nothing (one memcpy of H*W floats) and avoids a silent- + // misprediction footgun when optimize_identity_mask=false. + // + // The MLX/GPU path slices channel 0 itself via mx::slice and does + // not read userInputMaskBuffer or userInputBufferNCHW. + if(computeHandle->coremlOnlyHandle) { + const int C = computeHandle->numInputChannels; + const size_t HW = inputBuffers->singleMaskElts; // nnXLen * nnYLen + float* rowNCHW = inputBuffers->userInputBufferNCHW.data() + + inputBuffers->singleInputElts * nIdx; + const float* rowNHWC = rowSpatialInput; // [H*W, C] + for(int c = 0; c < C; c++) { + float* dstCh = rowNCHW + (size_t)c * HW; + for(size_t hw = 0; hw < HW; hw++) { + dstCh[hw] = rowNHWC[hw * C + c]; + } + } + float* dstMask = inputBuffers->userInputMaskBuffer.data() + + inputBuffers->singleMaskElts * nIdx; + std::memcpy(dstMask, rowNCHW, HW * sizeof(float)); + } + } + + // Dispatch to appropriate path based on mux mode. + if(computeHandle->coremlOnlyHandle) { + // ANE path: dispatch through the Swift CoreMLComputeHandle. Swift + // creates MLMultiArray(shape: [1, C, H, W]) per row and memcpys + // C*H*W floats — strict NCHW. We pass userInputBufferNCHW (rows + // transposed from NHWC in the loop above) instead of spatialInput. + // The mask is the contiguous H*W float prefix of each NCHW row, + // already lifted into userInputMaskBuffer above. The mlpackage + // ignores the mask buffer iff it was converted with + // optimize_identity_mask=true. + computeHandle->coremlOnlyHandle.get().apply( + inputBuffers->userInputBufferNCHW.data(), + inputBuffers->globalInput.data(), + inputBuffers->metaInput.data(), // always non-null (resized to at least 1 in InputBuffers ctor) + inputBuffers->userInputMaskBuffer.data(), + inputBuffers->policyResults.data(), + inputBuffers->policyPassResults.data(), + inputBuffers->valueResults.data(), + inputBuffers->scoreValueResults.data(), + inputBuffers->ownershipResults.data(), + batchSize); + } else { + // GPU path: run the MLX compiled function exactly as before. + const bool useMask = !computeHandle->requireExactNNLen; + const bool hasMeta = (numMetaFeatures > 0); + const CompiledInferenceFunc& compiledFunc = computeHandle->getCompiledFunc(batchSize, nnXLen, nnYLen, useMask, hasMeta); + + computeHandle->model->applyCompiled( + compiledFunc, + inputBuffers->spatialInput.data(), + inputBuffers->globalInput.data(), + (numMetaFeatures > 0 ? inputBuffers->metaInput.data() : nullptr), + batchSize, + nnXLen, + nnYLen, + computeHandle->requireExactNNLen, + inputBuffers->policyResults.data(), + inputBuffers->policyPassResults.data(), + inputBuffers->valueResults.data(), + inputBuffers->scoreValueResults.data(), + inputBuffers->ownershipResults.data() + ); + } + + assert(inputBuffers->singlePolicyPassResultElts == (size_t)computeHandle->numPolicyPassChannels); + assert(inputBuffers->singlePolicyResultElts == numPolicyChannels * nnXLen * nnYLen); + assert(outputs.size() == batchSize); + + float policyProbsTmp[NNPos::MAX_NN_POLICY_SIZE]; + + float* policyData = inputBuffers->policyResults.data(); + float* policyPassData = inputBuffers->policyPassResults.data(); + float* valueData = inputBuffers->valueResults.data(); + float* scoreValueData = inputBuffers->scoreValueResults.data(); + float* ownershipData = inputBuffers->ownershipResults.data(); + + for(int row = 0; row < batchSize; row++) { + NNOutput* output = outputs[row]; + assert(output->nnXLen == nnXLen); + assert(output->nnYLen == nnYLen); + float policyOptimism = (float)inputBufs[row]->policyOptimism; + + const float* policyPassSrcBuf = policyPassData + row * computeHandle->numPolicyPassChannels; + const float* policySrcBuf = policyData + row * numPolicyChannels * nnXLen * nnYLen; + float* policyProbs = output->policyProbs; + + // Handle policy optimism (version >= 12). The optimism mix uses + // channel 0 (p) and channel 1 (pOpt) of the policy output; v16+ + // channels 2-3 are ignored here, matching MetalProcess::processOptimism + // in metalbackend.cpp. + // + // MLX/GPU writes NHWC: channels are interleaved per spatial position. + // CoreML/ANE writes NCHW (MLMultiArray shape [1, C, H, W], contiguous + // memcpy in metalbackend.swift copyMultiArray): channel 0 occupies the + // first HW floats, channel 1 the next HW, etc. Stride differs per path. + if(numPolicyChannels == 2 || (numPolicyChannels == 4 && modelVersion >= 16)) { + const int HW = nnXLen * nnYLen; + const bool isNCHW = (bool)computeHandle->coremlOnlyHandle; + const int strideI = isNCHW ? 1 : numPolicyChannels; + const int strideOpt = isNCHW ? HW : 1; + for(int i = 0; i < HW; i++) { + float p = policySrcBuf[i * strideI]; + float pOpt = policySrcBuf[i * strideI + strideOpt]; + policyProbsTmp[i] = p + (pOpt - p) * policyOptimism; + } + SymmetryHelpers::copyOutputsWithSymmetry(policyProbsTmp, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + policyProbs[nnXLen * nnYLen] = policyPassSrcBuf[0] + (policyPassSrcBuf[1] - policyPassSrcBuf[0]) * policyOptimism; + } + else { + // Fail loud (not a release-elided assert) on an unsupported channel count: + // anything other than 1 here would silently mis-stride the single-channel read. + if(numPolicyChannels != 1) + throw StringError("MLX backend: unsupported numPolicyChannels=" + Global::intToString(numPolicyChannels)); + SymmetryHelpers::copyOutputsWithSymmetry(policySrcBuf, policyProbs, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + policyProbs[inputBuffers->singlePolicyResultElts] = policyPassSrcBuf[0]; + } + + int numValueChannels = computeHandle->numValueChannels; + assert(numValueChannels == 3); + output->whiteWinProb = valueData[row * numValueChannels]; + output->whiteLossProb = valueData[row * numValueChannels + 1]; + output->whiteNoResultProb = valueData[row * numValueChannels + 2]; + + if(output->whiteOwnerMap != NULL) { + const float* ownershipSrcBuf = ownershipData + row * nnXLen * nnYLen; + assert(computeHandle->numOwnershipChannels == 1); + SymmetryHelpers::copyOutputsWithSymmetry(ownershipSrcBuf, output->whiteOwnerMap, 1, nnYLen, nnXLen, inputBufs[row]->symmetry); + } + + if(modelVersion >= 9) { + int numScoreValueChannels = computeHandle->numScoreValueChannels; + assert(numScoreValueChannels == 6); + output->whiteScoreMean = scoreValueData[row * numScoreValueChannels]; + output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1]; + output->whiteLead = scoreValueData[row * numScoreValueChannels + 2]; + output->varTimeLeft = scoreValueData[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = scoreValueData[row * numScoreValueChannels + 4]; + output->shorttermScoreError = scoreValueData[row * numScoreValueChannels + 5]; + } + else if(modelVersion >= 8) { + int numScoreValueChannels = computeHandle->numScoreValueChannels; + assert(numScoreValueChannels == 4); + output->whiteScoreMean = scoreValueData[row * numScoreValueChannels]; + output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1]; + output->whiteLead = scoreValueData[row * numScoreValueChannels + 2]; + output->varTimeLeft = scoreValueData[row * numScoreValueChannels + 3]; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else if(modelVersion >= 4) { + int numScoreValueChannels = computeHandle->numScoreValueChannels; + assert(numScoreValueChannels == 2); + output->whiteScoreMean = scoreValueData[row * numScoreValueChannels]; + output->whiteScoreMeanSq = scoreValueData[row * numScoreValueChannels + 1]; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else if(modelVersion >= 3) { + int numScoreValueChannels = computeHandle->numScoreValueChannels; + assert(numScoreValueChannels == 1); + output->whiteScoreMean = scoreValueData[row * numScoreValueChannels]; + output->whiteScoreMeanSq = output->whiteScoreMean * output->whiteScoreMean; + output->whiteLead = output->whiteScoreMean; + output->varTimeLeft = 0; + output->shorttermWinlossError = 0; + output->shorttermScoreError = 0; + } + else { + ASSERT_UNREACHABLE; + } + } +} + +void NeuralNet::printDevices() { + cout << "MLX Backend (Apple Silicon)" << endl; + cout << "Default device: " << mx::default_device() << endl; +} + +// FOR TESTING --------------------------------------------------------------------------------------------------------- + +bool NeuralNet::testEvaluateConv( + const ConvLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + vector& outputBuffer +) { + // Run MLX-specific aux tests (Winograd kernel + tuner) exactly once per + // process, on the first invocation of testEvaluateConv. This is the + // MLX-side hook reachable from Tests::runNNLayerTests through + // testConvLayer, allowing testnn.cpp to stay backend-agnostic. + // The flag is set BEFORE the calls so a propagating exception does not + // cause the aux tests to re-run on subsequent conv configs. + static bool ranMLXAuxTests = false; + if(!ranMLXAuxTests) { + ranMLXAuxTests = true; + runMLXWinogradTests(); + runMLXWinotunerTests(); + } + + if(!useNHWC) { + return false; // MLX only supports NHWC + } + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->outChannels; + outputBuffer.resize(numOutputFloats); + + MLXWinograd::InputTransform defaultInCfg; + MLXWinograd::OutputUntransform defaultOutCfg; + ConvLayer layer(*desc, defaultInCfg, defaultOutCfg, useFP16); + mx::Shape inputShape = {batchSize, nnYLen, nnXLen, desc->inChannels}; + mx::array input = mx::array(inputBuffer.data(), inputShape, mx::float32); + mx::array computeInput = toComputeDtype(input, useFP16); + mx::array output = layer.apply(computeInput); + if(useFP16) output = mx::astype(output, mx::float32); + mx::eval(output); + + memcpy(outputBuffer.data(), output.data(), numOutputFloats * sizeof(float)); + return true; +} + +bool NeuralNet::testEvaluateBatchNorm( + const BatchNormLayerDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + if(!useNHWC) { + return false; + } + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->numChannels; + outputBuffer.resize(numOutputFloats); + + BatchNormLayer layer(*desc, ACTIVATION_IDENTITY, useFP16); + mx::Shape inputShape = {batchSize, nnYLen, nnXLen, desc->numChannels}; + mx::Shape maskShape = {batchSize, nnYLen, nnXLen, 1}; + mx::array input = mx::array(inputBuffer.data(), inputShape, mx::float32); + mx::array mask = mx::array(maskBuffer.data(), maskShape, mx::float32); + mx::array computeInput = toComputeDtype(input, useFP16); + mx::array computeMask = toComputeDtype(mask, useFP16); + mx::array output = layer.apply(computeInput, computeMask, /*useMask=*/true); + if(useFP16) output = mx::astype(output, mx::float32); + mx::eval(output); + + memcpy(outputBuffer.data(), output.data(), numOutputFloats * sizeof(float)); + return true; +} + +bool NeuralNet::testEvaluateResidualBlock( + const ResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + if(!useNHWC) { + return false; + } + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; + outputBuffer.resize(numOutputFloats); + + MLXWinograd::InputTransform defaultInCfg; + MLXWinograd::OutputUntransform defaultOutCfg; + ResidualBlock block(*desc, defaultInCfg, defaultOutCfg, useFP16); + mx::Shape inputShape = {batchSize, nnYLen, nnXLen, desc->preBN.numChannels}; + mx::Shape maskShape = {batchSize, nnYLen, nnXLen, 1}; + mx::array input = mx::array(inputBuffer.data(), inputShape, mx::float32); + mx::array mask = mx::array(maskBuffer.data(), maskShape, mx::float32); + mx::array computeInput = toComputeDtype(input, useFP16); + mx::array computeMask = toComputeDtype(mask, useFP16); + mx::array output = block.apply(computeInput, computeMask, /*useMask=*/true); + if(useFP16) output = mx::astype(output, mx::float32); + mx::eval(output); + + memcpy(outputBuffer.data(), output.data(), numOutputFloats * sizeof(float)); + return true; +} + +bool NeuralNet::testEvaluateGlobalPoolingResidualBlock( + const GlobalPoolingResidualBlockDesc* desc, + int batchSize, + int nnXLen, + int nnYLen, + bool useFP16, + bool useNHWC, + const vector& inputBuffer, + const vector& maskBuffer, + vector& outputBuffer +) { + if(!useNHWC) { + return false; + } + + size_t numOutputFloats = (size_t)batchSize * nnXLen * nnYLen * desc->preBN.numChannels; + outputBuffer.resize(numOutputFloats); + + MLXWinograd::InputTransform defaultInCfg; + MLXWinograd::OutputUntransform defaultOutCfg; + GlobalPoolingResidualBlock block(*desc, defaultInCfg, defaultOutCfg, useFP16); + mx::Shape inputShape = {batchSize, nnYLen, nnXLen, desc->preBN.numChannels}; + mx::Shape maskShape = {batchSize, nnYLen, nnXLen, 1}; + mx::array input = mx::array(inputBuffer.data(), inputShape, mx::float32); + mx::array mask = mx::array(maskBuffer.data(), maskShape, mx::float32); + mx::array computeInput = toComputeDtype(input, useFP16); + mx::array computeMask = toComputeDtype(mask, useFP16); + std::vector sumAxes = {1, 2}; + // maskSum stays FP32 for precision + mx::array maskSum = mx::sum(mask, sumAxes, /*keepdims=*/true); + mx::array output = block.apply(computeInput, computeMask, maskSum, /*useMask=*/true); + if(useFP16) output = mx::astype(output, mx::float32); + mx::eval(output); + + memcpy(outputBuffer.data(), output.data(), numOutputFloats * sizeof(float)); + return true; +} + +// Directly-asserting unit test for BatchNormLayer fp16 mode. +// Declared here because BatchNormLayer is not in any public header. +// Called from runMLXWinogradTests() in mlxtests.cpp. +void runMLXBatchNormFP16Test() { + namespace mxc = mx; // reuse the file-scope `mx` alias + using std::cout; + using std::endl; + + int N=1,H=5,W=5,C=4; + std::vector mean(C, 0.0f), variance(C, 1.0f), scale(C, 1.0f), bias(C, 0.0f); + BatchNormLayerDesc bnDesc; + bnDesc.name = "bnFP16Test"; + bnDesc.numChannels = C; + bnDesc.epsilon = 1e-5f; + bnDesc.mean = mean; + bnDesc.variance = variance; + bnDesc.scale = scale; + bnDesc.bias = bias; + BatchNormLayer bn(bnDesc, ACTIVATION_IDENTITY, /*useFP16=*/true); + + // mergedScale/mergedBias must be fp32 even in fp16 mode. + testAssert(bn.mergedScale.dtype() == mxc::float32); + testAssert(bn.mergedBias.dtype() == mxc::float32); + + // apply() must return fp16 when useFP16=true. + std::vector inV((size_t)N*H*W*C, 0.5f); + std::vector maskV((size_t)N*H*W*1, 1.0f); + mxc::array inArrF32(inV.data(), {N,H,W,C}, mxc::float32); + mxc::array inArr = mxc::astype(inArrF32, mxc::float16); + mxc::array maskArrF32(maskV.data(), {N,H,W,1}, mxc::float32); + mxc::array maskArr = mxc::astype(maskArrF32, mxc::float16); + mxc::array out = bn.apply(inArr, maskArr, /*useMask=*/true); + mxc::eval(out); + testAssert(out.dtype() == mxc::float16); + cout << " BatchNormLayer fp16: mergedScale/Bias fp32, output fp16 OK" << endl; +} + +// Directly-asserting unit test for ConvLayer fp16 Winograd path. +// Declared here because ConvLayer is not in any public header. +// Called from runMLXWinogradTests() in mlxtests.cpp. +void runMLXConvLayerFP16WinogradTest() { + namespace mxc = mx; // reuse the file-scope `mx` alias + using std::cout; + using std::endl; + + int N=1,H=19,W=19,Cin=8,Cout=16; + std::mt19937 grng(779); + std::uniform_real_distribution gdist(-1.f,1.f); + std::vector in((size_t)N*H*W*Cin); for(auto&x:in)x=gdist(grng); + std::vector w((size_t)Cout*Cin*9); for(auto&x:w)x=gdist(grng); + auto refv = MLXWinograd::cpuConv2d3x3(in,N,H,W,Cin,w,Cout); + + ConvLayerDesc convDesc; + convDesc.name = "convFP16WinogradTest"; + convDesc.convYSize = 3; + convDesc.convXSize = 3; + convDesc.inChannels = Cin; + convDesc.outChannels = Cout; + convDesc.dilationY = 1; + convDesc.dilationX = 1; + convDesc.weights = w; + + MLXWinograd::InputTransform inCfg; + MLXWinograd::OutputUntransform outCfg; + ConvLayer conv(convDesc, inCfg, outCfg, /*useFP16=*/true); + testAssert(conv.useWinograd); // fp16 still picks Winograd + + mxc::array inArrF32(in.data(),{N,H,W,Cin},mxc::float32); + mxc::array inArr = mxc::astype(inArrF32, mxc::float16); + mxc::array o = conv.apply(inArr); + mxc::eval(o); + testAssert(o.dtype() == mxc::float16); + mxc::array oF32 = mxc::astype(o, mxc::float32); + mxc::eval(oF32); + const float* od = oF32.data(); + double maxErr=0.0; + for(size_t i=0;i apply() == trunk exactly. +void runMLXTransformerLayerFP16Test() { + namespace mxc = mx; // reuse the file-scope `mx` alias + using std::cout; + using std::endl; + + std::mt19937 rng(20260607u); + std::uniform_real_distribution dist(-0.3f, 0.3f); // keep fp16 well-conditioned + auto fillRand = [&](std::vector& v){ for(auto& x : v) x = dist(rng); }; + + // ---- (1)+(2) RMSNorm correctness + dtype ---- + { + const int N=1,H=3,W=3,C=8; + const float eps=1e-5f; + std::vector weightV(C); fillRand(weightV); + // Bias the weights to ~1.0 so the normalize is not degenerate. + for(auto& x : weightV) x += 1.0f; + + TransformerRMSNormDesc desc; + desc.name = "rmsnormFP16Test"; + desc.numChannels = C; + desc.epsilon = eps; + desc.weight = weightV; + + std::vector inV((size_t)N*H*W*C); fillRand(inV); + std::vector maskV((size_t)N*H*W*1, 1.0f); + + mxc::array inArrF32(inV.data(), {N,H,W,C}, mxc::float32); + mxc::array maskF32(maskV.data(), {N,H,W,1}, mxc::float32); + + // fp32 layer; useMask=false. + TransformerRMSNormLayer rmsF32(desc, /*useFP16=*/false); + testAssert(rmsF32.weight.dtype() == mxc::float32); + mxc::array outF32 = rmsF32.apply(inArrF32, maskF32, /*useMask=*/false); + mxc::eval(outF32); + const float* op = outF32.data(); + + // CPU reference: per position, ms = mean_c(x^2); r = 1/sqrt(ms+eps); + // out_c = x_c * r * weight_c. + double maxErr=0.0; + for(int pos=0; pos lnW(inChannels); fillRand(lnW); + for(auto& x : lnW) x += 1.0f; + desc.preLN.weight = lnW; + // Projections. + makeMatMulDesc(desc.qProj, "qProj", inChannels, numHeads*qHeadDim, false); + makeMatMulDesc(desc.kProj, "kProj", inChannels, numKVHeads*qHeadDim, false); + makeMatMulDesc(desc.vProj, "vProj", inChannels, numKVHeads*vHeadDim, false); + makeMatMulDesc(desc.outProj,"outProj",numHeads*vHeadDim, inChannels, zeroOutProj); + }; + + // ---- (3) Attention dtype preservation (THE §1 guard) + fp16-vs-fp32 ---- + { + const int N=1, nnX=3, nnY=3; // seq=9 + const int inChannels=8, numHeads=2, numKVHeads=1, qHeadDim=4, vHeadDim=4; + + // Fixed weights (one RNG draw) reused by the fp16 and fp32 blocks so the two + // can be compared. Build the desc once for fp16 and an identical one for fp32. + auto runVariant = [&](bool useRope){ + // descF16 and descF32 must get BIT-IDENTICAL weights so the only difference + // between the two blocks is the compute dtype. buildOne therefore seeds a + // FRESH RNG from the same base seed on every call (a single shared RNG would + // hand the second desc a different draw, i.e. a different random network -- + // the comparison would then measure network divergence, not fp16 error). + const unsigned baseSeed = 424242u + (useRope?1u:0u); + auto buildOne = [&](TransformerAttentionDesc& desc){ + std::mt19937 localRng(baseSeed); + std::uniform_real_distribution ld(-0.3f, 0.3f); + auto lfill = [&](std::vector& v){ for(auto& x:v) x=ld(localRng); }; + desc.name = "attnFP16Test"; + desc.numHeads = numHeads; desc.numKVHeads = numKVHeads; + desc.qHeadDim = qHeadDim; desc.vHeadDim = vHeadDim; + desc.useRope = useRope; desc.learnableRope = false; desc.ropeTheta = 10000.0f; + desc.preLN.name = "attnPreLN"; desc.preLN.numChannels = inChannels; desc.preLN.epsilon = 1e-5f; + std::vector lnW(inChannels); lfill(lnW); for(auto& x:lnW) x+=1.0f; desc.preLN.weight = lnW; + auto mk = [&](MatMulLayerDesc& d, const std::string& nm, int inC, int outC){ + d.name=nm; d.inChannels=inC; d.outChannels=outC; + d.weights.assign((size_t)inC*outC,0.0f); lfill(d.weights); + }; + mk(desc.qProj, "qProj", inChannels, numHeads*qHeadDim); + mk(desc.kProj, "kProj", inChannels, numKVHeads*qHeadDim); + mk(desc.vProj, "vProj", inChannels, numKVHeads*vHeadDim); + mk(desc.outProj,"outProj",numHeads*vHeadDim, inChannels); + }; + + TransformerAttentionDesc descF16; buildOne(descF16); + TransformerAttentionDesc descF32; buildOne(descF32); // same RNG seed => same weights + + std::vector trunkV((size_t)N*nnY*nnX*inChannels); + { std::mt19937 tr(99u); std::uniform_real_distribution td(-0.3f,0.3f); + for(auto& x:trunkV) x=td(tr); } + std::vector maskV((size_t)N*nnY*nnX*1, 1.0f); + // For the useMask=true case, mask out two positions. + maskV[2]=0.0f; maskV[5]=0.0f; + + mxc::array trunkF32(trunkV.data(), {N,nnY,nnX,inChannels}, mxc::float32); + mxc::array maskF32(maskV.data(), {N,nnY,nnX,1}, mxc::float32); + mxc::array trunkF16 = mxc::astype(trunkF32, mxc::float16); + mxc::array maskF16 = mxc::astype(maskF32, mxc::float16); + + TransformerAttentionBlock blkF16(descF16, nnX, nnY, /*useFP16=*/true); + TransformerAttentionBlock blkF32(descF32, nnX, nnY, /*useFP16=*/false); + + for(bool useMask : {true, false}) { + mxc::array oF16 = blkF16.apply(trunkF16, maskF16, useMask); + mxc::eval(oF16); + // THE guard: output dtype must stay fp16 (a stray fp32 scalar promotes it). + testAssert(oF16.dtype() == mxc::float16); + mxc::array oF16to32 = mxc::astype(oF16, mxc::float32); + mxc::eval(oF16to32); + const float* p16 = oF16to32.data(); + bool finite=true; for(size_t i=0;i(); + double maxErr=0.0, maxAbs=0.0; + for(size_t i=0;i +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +namespace mx = mlx::core; + +// Defined in mlxbackend.cpp — they need the file-local BatchNormLayer / +// ConvLayer classes, so they cannot move here. +void runMLXBatchNormFP16Test(); +void runMLXConvLayerFP16WinogradTest(); +void runMLXTransformerLayerFP16Test(); + +// Fused-epilogue parity: winogradConv2d(Epilogue::bnAct/residual) must equal +// the unfused decomposition (conv -> BN+act, conv -> +residual) computed with +// plain MLX ops. fp32: exact (Tier 1). fp16: bit-exact for residual; bnAct +// allows a tiny tolerance for in-kernel vs JIT transcendental last bits (Tier 2). +static void runMLXWinogradEpilogueTests() { + namespace mxc = mlx::core; + using namespace MLXWinograd; + cout << "Running MLX Winograd epilogue-fusion tests" << endl; + const int N=2,H=19,W=19,C=24; // square 3x3, Cin==Cout for residual + std::mt19937 rng(9091); + std::uniform_real_distribution d(-1.f,1.f); + vector in((size_t)N*H*W*C); for(auto&x:in) x=d(rng); + vector w((size_t)C*C*9); for(auto&x:w) x=d(rng); + vector scale(C), bias(C); for(int i=0;i() so their + // buffers outlive the read (a temporary astype result would be freed at + // the end of the full expression, leaving the pointer dangling in fp16). + mxc::array uF=mxc::astype(unfused,mxc::float32); mxc::eval(uF); + auto* a=uF.data(); + mxc::array fF=mxc::astype(fused,mxc::float32); mxc::eval(fF); + auto* b=fF.data(); + double mx=0; for(size_t i=0;i& in,int N,int H,int W,int Cin, + const vector& w,int Cout){ + vector out((size_t)N*H*W*Cout,0.f); + for(int n=0;n=0&&iy=0&&ix dist(-1.f,1.f); + for(auto dims : vector>{{1,5,5,3,4},{2,19,19,8,16},{1,7,13,4,4}}){ + int N=dims[0],H=dims[1],W=dims[2],Cin=dims[3],Cout=dims[4]; + vector in((size_t)N*H*W*Cin); for(auto&x:in)x=dist(rng); + vector w((size_t)Cout*Cin*9); for(auto&x:w)x=dist(rng); + auto ref = direct(in,N,H,W,Cin,w,Cout); + auto got = MLXWinograd::cpuConv2d3x3(in,N,H,W,Cin,w,Cout); + double maxErr=0.0; + for(size_t i=0;i"< gdist(-1.f,1.f); + vector in((size_t)N*H*W*Cin); for(auto&x:in)x=gdist(grng); + vector w((size_t)Cout*Cin*9); for(auto&x:w)x=gdist(grng); + auto refv = MLXWinograd::cpuConv2d3x3(in,N,H,W,Cin,w,Cout); + mxc::array inArr(in.data(),{N,H,W,Cin},mxc::float32); + auto Uw = MLXWinograd::makeWinogradWeights(w,Cout,Cin); + MLXWinograd::InputTransform inCfg; + MLXWinograd::OutputUntransform outCfg; + mxc::array o = MLXWinograd::winogradConv2d(inArr,Uw,Cout,inCfg,outCfg); + mxc::eval(o); + const float* od = o.data(); + double maxErr=0.0; + for(size_t i=0;i Ntiles = 2*10*10 = 200. + { + using namespace MLXWinograd; + namespace mx = mlx::core; + std::vector in_data((size_t)2*19*19*64); + std::mt19937 rng(0x1234u); + std::uniform_real_distribution fdist(-1.0f, 1.0f); + for(auto& x : in_data) x = fdist(rng); + mx::array inp(in_data.data(), {2, 19, 19, 64}, mx::float32); + + std::vector w_data((size_t)64*64*9, 1.0f); + mx::array Uw = makeWinogradWeights(w_data, 64, 64, false); + + auto runWith = [&](int wpt_in, int wpt_out) { + InputTransform inCfg; inCfg.wpt = wpt_in; + OutputUntransform outCfg; outCfg.wpt = wpt_out; + mx::array out = winogradConv2d(inp, Uw, 64, inCfg, outCfg, false); + mx::eval(out); + return out; + }; + + // Vary input WPT, output stays at WPT=1. + mx::array out_w1 = runWith(1, 1); + mx::array out_w4 = runWith(4, 1); + mx::array out_w8 = runWith(8, 1); + // Vary output WPT, input stays at WPT=1. + mx::array out_ow4 = runWith(1, 4); + mx::array out_ow8 = runWith(1, 8); + + // Compare bit-for-bit (no FP-ordering change — only thread loop unroll differs). + const float* p1 = out_w1.data(); + const float* p4 = out_w4.data(); + const float* p8 = out_w8.data(); + const float* po4 = out_ow4.data(); + const float* po8 = out_ow8.data(); + size_t n = (size_t)2 * 19 * 19 * 64; + for(size_t i = 0; i < n; i++) { + testAssert(p1[i] == p4[i]); + testAssert(p1[i] == p8[i]); + testAssert(p1[i] == po4[i]); + testAssert(p1[i] == po8[i]); + } + cout << " MLX Winograd WPT bit-for-bit equivalence (1/4/8) passed" << endl; + } + + // Tail-guard coverage: Ntiles=100 (N=1, H=W=19) is NOT + // divisible by WPT=8, so the last thread along the slow axis has + // tileIdx in {96..103}; iterations 100..103 must hit the break. + { + using namespace MLXWinograd; + namespace mx = mlx::core; + std::vector in_data((size_t)1*19*19*64); + std::mt19937 rng(0xBEEFu); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : in_data) x = dist(rng); + mx::array inp(in_data.data(), {1, 19, 19, 64}, mx::float32); + + std::vector w_data((size_t)64*64*9, 1.0f); + mx::array Uw = makeWinogradWeights(w_data, 64, 64, false); + + auto runWith = [&](int wpt_in, int wpt_out) { + InputTransform inCfg; inCfg.wpt = wpt_in; + OutputUntransform outCfg; outCfg.wpt = wpt_out; + mx::array out = winogradConv2d(inp, Uw, 64, inCfg, outCfg, false); + mx::eval(out); + return out; + }; + + mx::array out_w1 = runWith(1, 1); + mx::array out_w8in = runWith(8, 1); // input WPT=8 with Ntiles%WPT != 0 + mx::array out_w8out = runWith(1, 8); // output WPT=8 with Ntiles%WPT != 0 + + const float* p1 = out_w1.data(); + const float* p8i = out_w8in.data(); + const float* p8o = out_w8out.data(); + size_t n = (size_t)1 * 19 * 19 * 64; + for(size_t i = 0; i < n; i++) { + testAssert(p1[i] == p8i[i]); + testAssert(p1[i] == p8o[i]); + } + cout << " MLX Winograd WPT tail-guard coverage (Ntiles=100, WPT=8) passed" << endl; + } + + // Input VW=1, 2, 4 must produce bit-identical fp16 output (Cfast). C=64 + // is divisible by 4 — VW=4 valid. Output VW is gone (kernel is VW=1 + // monomorphic). + { + using namespace MLXWinograd; + namespace mx = mlx::core; + std::vector in_data((size_t)2*19*19*64); + std::mt19937 rng(0x9ABCu); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : in_data) x = dist(rng); + mx::array inp = mx::astype(mx::array(in_data.data(), {2, 19, 19, 64}, mx::float32), mx::float16); + + std::vector w_data((size_t)64*64*9, 0.5f); + mx::array Uw = makeWinogradWeights(w_data, 64, 64, true); + + auto runWith = [&](int vw_in) { + InputTransform inCfg; inCfg.vw = vw_in; + OutputUntransform outCfg; + mx::array out = winogradConv2d(inp, Uw, 64, inCfg, outCfg, true); + mx::eval(out); + return out; + }; + + mx::array out_v1 = runWith(1); + mx::array out_v2in = runWith(2); + mx::array out_v4in = runWith(4); + + // Cast to fp32 and compare bit-for-bit (no FP-op reordering — only channel + // sequencing differs across input VW, so equality must hold exactly). + mx::array out_v1_fp32 = mx::astype(out_v1, mx::float32); + mx::array out_v2in_fp32 = mx::astype(out_v2in, mx::float32); + mx::array out_v4in_fp32 = mx::astype(out_v4in, mx::float32); + mx::eval(out_v1_fp32, out_v2in_fp32, out_v4in_fp32); + const float* p1 = out_v1_fp32.data(); + const float* p2i = out_v2in_fp32.data(); + const float* p4i = out_v4in_fp32.data(); + size_t n = (size_t)2 * 19 * 19 * 64; + for(size_t i = 0; i < n; i++) { + testAssert(p1[i] == p2i[i]); + testAssert(p1[i] == p4i[i]); + } + cout << " MLX Winograd input-VW bit-for-bit equivalence (1/2/4 fp16, Cfast) passed" << endl; + } + + // Input-stage GridOrder::Cfast and GridOrder::Tfast must produce + // bit-identical fp32 output. They differ only in which thread does which + // (c, tileIdx) pair; the on-disk layout is unchanged. The output kernel + // is Cfast-monomorphic, so only the input gridOrder is varied here. + { + using namespace MLXWinograd; + namespace mx = mlx::core; + std::vector in_data((size_t)2*19*19*64); + std::mt19937 rng(0xDEADu); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : in_data) x = dist(rng); + mx::array inp(in_data.data(), {2, 19, 19, 64}, mx::float32); + + std::vector w_data((size_t)64*64*9); + for(auto& x : w_data) x = dist(rng); + mx::array Uw = makeWinogradWeights(w_data, 64, 64, false); + + auto runWith = [&](GridOrder go_in) { + InputTransform inC; inC.gridOrder = go_in; + OutputUntransform outC; + mx::array out = winogradConv2d(inp, Uw, 64, inC, outC, false); + mx::eval(out); + return out; + }; + + // Input Cfast (baseline). + mx::array out_c = runWith(GridOrder::Cfast); + // Input Tfast — kernel swaps thread mapping, output must match. + mx::array out_t = runWith(GridOrder::Tfast); + + const float* pc = out_c.data(); + const float* pt = out_t.data(); + size_t n = (size_t)2 * 19 * 19 * 64; + for(size_t i = 0; i < n; i++) { + testAssert(pc[i] == pt[i]); + } + std::cout << " MLX Winograd input-stage Cfast vs Tfast bit-for-bit equivalence passed" << std::endl; + } + + // Tail-guard coverage: input Tfast with C=67 (not + // divisible by WPT=8). Last thread group has only 3 channels (67 % 8 = 3); + // the tail-guard `if (c >= C_k) break;` fires for the other 5 iterations. + // We verify input Tfast still matches input Cfast for this shape. + { + using namespace MLXWinograd; + namespace mx = mlx::core; + std::vector in_data((size_t)1*19*19*67); + std::mt19937 rng(0xFEEDu); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : in_data) x = dist(rng); + mx::array inp(in_data.data(), {1, 19, 19, 67}, mx::float32); + + std::vector w_data((size_t)67*67*9); + for(auto& x : w_data) x = dist(rng); + mx::array Uw = makeWinogradWeights(w_data, 67, 67, false); + + auto runWith = [&](GridOrder go, int wpt) { + InputTransform inC; inC.gridOrder = go; inC.wpt = wpt; + OutputUntransform outC; outC.wpt = wpt; + mx::array out = winogradConv2d(inp, Uw, 67, inC, outC, false); + mx::eval(out); + return out; + }; + + mx::array out_cfast = runWith(GridOrder::Cfast, 1); + mx::array out_tfast = runWith(GridOrder::Tfast, 8); // input Tfast with WPT=8, C=67 not divisible. + + const float* pc = out_cfast.data(); + const float* pt = out_tfast.data(); + size_t n = (size_t)1 * 19 * 19 * 67; + for(size_t i = 0; i < n; i++) { + testAssert(pc[i] == pt[i]); + } + std::cout << " MLX Winograd input-stage Tfast tail-guard coverage (C=67, WPT=8) passed" << std::endl; + } + + { + // Output kernel is monomorphic on VW=1, GRID_ORDER=Cfast. + // Run a full conv via winogradConv2d with a deterministic input and weight + // tensor; assert the output is finite and matches a stable reference + // checksum (sum of absolute values to 4 decimal places). This catches: + // - stale Tfast read paths in the output kernel + // - stale VW>1 vector-load paths + // - Std-only weight layout not consistent with kernel reads + namespace mx = mlx::core; + using namespace MLXWinograd; + + const int N = 1, H = 8, W = 8, Cin = 8, Cout = 8; + + // Deterministic input: i*0.01. + std::vector inData(N * H * W * Cin); + for(size_t i = 0; i < inData.size(); i++) inData[i] = (float)i * 0.01f; + mx::array input(inData.data(), {N, H, W, Cin}, mx::float32); + + // Deterministic 3x3 weights: (oc*Cin*9 + ic*9 + k)*0.001. + std::vector wData(Cout * Cin * 9); + for(size_t i = 0; i < wData.size(); i++) wData[i] = (float)i * 0.001f; + // makeWinogradWeights takes raw [Cout, Cin, 3, 3] flattened and produces + // the transformed [16, Cin, Cout] tensor (Std-only). + mx::array U = makeWinogradWeights(wData, Cout, Cin, /*useFP16=*/false); + + // Output config: Std OutputUntransform has tg0/tg1/wpt only. + InputTransform inCfg{}; + inCfg.tg0 = 32; inCfg.tg1 = 1; inCfg.wpt = 1; inCfg.vw = 1; + inCfg.gridOrder = GridOrder::Cfast; + OutputUntransform outCfg{}; + outCfg.tg0 = 16; outCfg.tg1 = 4; outCfg.wpt = 1; + + mx::array out = winogradConv2d(input, U, Cout, inCfg, outCfg); + mx::eval(out); + + // Output shape must be [N, H, W, Cout]. + testAssert(out.shape(0) == N); + testAssert(out.shape(1) == H); + testAssert(out.shape(2) == W); + testAssert(out.shape(3) == Cout); + + // Pull data; assert all finite. + std::vector outData(out.size()); + out.eval(); + std::memcpy(outData.data(), out.data(), outData.size() * sizeof(float)); + for(float v : outData) testAssert(std::isfinite(v)); + + // Stable checksum: sum of absolute values. This is a regression check — + // a change in numerics suggests a kernel-template mismatch (e.g., output + // kernel reads channels via VW>1 path that no longer exists, producing + // UB-flavored garbage). + double sumAbs = 0.0; + for(float v : outData) sumAbs += std::abs(v); + // Recompute this expected value once after the test is first written — + // it captures the deterministic conv result for the inputs above. The + // test passes thereafter as a regression check, not a correctness check. + // Tolerance: 0.5% to absorb minor reordering noise from MLX graph rewrites. + constexpr double expectedSumAbs = 22788.156637847424; // captured 2026-05-21 + testAssert(std::abs(sumAbs - expectedSumAbs) / expectedSumAbs < 0.005); + std::cout << " Output-kernel monomorphic smoke test OK" << std::endl; + } + + runMLXWinogradEpilogueTests(); +} + +void runMLXWinotunerTests() { + cout << "Running MLX Winograd tuner tests" << endl; + + { + // Conv-3x3 distribution formatter — pure-function test. Verifies the + // log-line format directly without any descriptor walk or GPU work. + // Order convention: pairs sorted descending by invocation + // count, ties broken by channel count descending. + + // Case A: two distinct shapes, each appearing once. Tie on count, so + // tie-break by channel count descending: 64 before 32. + { + std::map inputC = {{32, 1}, {64, 1}}; + std::map outputC = {{32, 1}, {64, 1}}; + std::string line = MLXWinogradTuner::formatConv3x3DistributionLine(2, inputC, outputC); + testAssert(line.find("MLX tuner conv3x3 distribution:") != std::string::npos); + testAssert(line.find("total=2") != std::string::npos); + testAssert(line.find("input_c=64:1,32:1") != std::string::npos); + testAssert(line.find("output_c=64:1,32:1") != std::string::npos); + } + + // Case B: asymmetric counts. 384 appears 36 times, 192 once. Sort by + // count descending, so 384 first regardless of channel-count order. + { + std::map inputC = {{384, 36}, {192, 1}}; + std::map outputC = {{384, 37}}; + std::string line = MLXWinogradTuner::formatConv3x3DistributionLine(37, inputC, outputC); + testAssert(line.find("total=37") != std::string::npos); + testAssert(line.find("input_c=384:36,192:1") != std::string::npos); + testAssert(line.find("output_c=384:37") != std::string::npos); + } + + // Case C: empty model — no 3x3 convs. Error handling: print the + // line with explicit "{}" markers; don't suppress. + { + std::map empty; + std::string line = MLXWinogradTuner::formatConv3x3DistributionLine(0, empty, empty); + testAssert(line.find("total=0") != std::string::npos); + testAssert(line.find("input_c={}") != std::string::npos); + testAssert(line.find("output_c={}") != std::string::npos); + } + std::cout << " conv3x3 distribution formatter OK" << std::endl; + } + + { + // planShapeRotation — pure-function tests. Verifies the selection rule + // (top-3, 3% threshold, 3-rep floor, proportional remainder) directly + // without any GPU work. + + // Case A: single shape — entire budget on that shape, weight = 1.0. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting({{192, 72}}); + testAssert(plan.size() == 1); + testAssert(plan[0].channels == 192); + testAssert(plan[0].measureReps == 19); + testAssert(std::abs(plan[0].weight - 1.0) < 1e-9); + } + + // Case B: two shapes both above threshold (b18c384nbt-like, after the + // 22:1 entry has already been dropped by threshold). Expected: + // work = 192*72, 128*5 = 13824, 640; weights 0.956, 0.044; + // round(0.956*19)=18, round(0.044*19)=1; floor bumps 1->3; dominant 18-2=16. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting({{192, 72}, {128, 5}}); + testAssert(plan.size() == 2); + testAssert(plan[0].channels == 192); + testAssert(plan[1].channels == 128); + testAssert(plan[0].measureReps == 16); + testAssert(plan[1].measureReps == 3); + testAssert(plan[0].measureReps + plan[1].measureReps == 19); + testAssert(std::abs(plan[0].weight + plan[1].weight - 1.0) < 1e-9); + testAssert(plan[0].weight > plan[1].weight); + } + + // Case C: minor shape below 3% threshold — dropped entirely, dominant + // absorbs all 19 reps. Histogram: 192:72 (work 13824, 95.5%), 22:1 (work 22, 0.15%). + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting({{192, 72}, {22, 1}}); + testAssert(plan.size() == 1); + testAssert(plan[0].channels == 192); + testAssert(plan[0].measureReps == 19); + testAssert(std::abs(plan[0].weight - 1.0) < 1e-9); + } + + // Case D: four shapes — top-3 cut drops the 4th, then threshold drops + // one more. Input: 384:60, 192:8, 128:5, 64:5. After top-3: drop 64:5. + // work remaining = 23040, 1536, 640; total 25216; 128's share = 2.54% < 3% + // -> drop 128. Final: 384 (93.75%) + 192 (6.25%). reps: round(0.9375*19)=18, + // round(0.0625*19)=1; floor bumps 1->3; dominant 18-2=16. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting( + {{384, 60}, {192, 8}, {128, 5}, {64, 5}}); + testAssert(plan.size() == 2); + testAssert(plan[0].channels == 384); + testAssert(plan[1].channels == 192); + testAssert(plan[0].measureReps == 16); + testAssert(plan[1].measureReps == 3); + } + + // Case E: three shapes all above threshold. Input: 200:10, 100:10, 50:10. + // work = 2000, 1000, 500; total 3500; shares 57.1%, 28.6%, 14.3% (all >3%). + // reps: round(0.571*19)=11, round(0.286*19)=5, round(0.143*19)=3. + // Sum = 19 exactly (no rounding repair needed). All >= floor of 3. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting( + {{200, 10}, {100, 10}, {50, 10}}); + testAssert(plan.size() == 3); + testAssert(plan[0].channels == 200); + testAssert(plan[1].channels == 100); + testAssert(plan[2].channels == 50); + int total = plan[0].measureReps + plan[1].measureReps + plan[2].measureReps; + testAssert(total == 19); + testAssert(plan[2].measureReps >= 3); + testAssert(plan[0].measureReps >= plan[1].measureReps); + testAssert(plan[1].measureReps >= plan[2].measureReps); + } + + // Case F: 2 shapes with equal work and complementary 0.5 shares — + // exercises the rounding-repair branch. Input: 200:1, 100:2 (work + // 200, 200; tied; tie-break by larger C → plan[0]=C=200). Each + // share is 0.5; lround(0.5*19) = lround(9.5) = 10 each (lround + // rounds halves away from zero); pre-repair sum = 20; repair: + // dominant absorbs delta = 19 - 20 = -1; final (9, 10). Both + // measureReps stay ≥ kRepFloor=3 so floor-bump is a no-op. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting( + {{200, 1}, {100, 2}}); + testAssert(plan.size() == 2); + testAssert(plan[0].channels == 200); + testAssert(plan[1].channels == 100); + testAssert(plan[0].measureReps + plan[1].measureReps == 19); + testAssert(plan[0].measureReps == 9); + testAssert(plan[1].measureReps == 10); + testAssert(plan[0].measureReps >= 3); + testAssert(plan[1].measureReps >= 3); + } + + // Case G: coarse rep budget (full=false, the per-model-load path). Single + // dominant shape gets the entire 7-rep coarse budget (vs 19 for full). + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting({{384, 72}}, /*full=*/false); + testAssert(plan.size() == 1); + testAssert(plan[0].channels == 384); + testAssert(plan[0].measureReps == 7); + testAssert(std::abs(plan[0].weight - 1.0) < 1e-9); + } + + // Case H: coarse budget, three shapes above threshold. 200:10, 100:10, + // 50:10 → shares 57.1%, 28.6%, 14.3%; lround(share*7) = 4,2,1; the 2-rep + // coarse floor bumps the trailing 1 up to 2, deficit out of the dominant + // → (3,2,2), Σ=7. Mirrors Case E but on the coarse budget. + { + auto plan = MLXWinogradTuner::planShapeRotationForTesting( + {{200, 10}, {100, 10}, {50, 10}}, /*full=*/false); + testAssert(plan.size() == 3); + int total = plan[0].measureReps + plan[1].measureReps + plan[2].measureReps; + testAssert(total == 7); + testAssert(plan[2].measureReps >= 2); // coarse floor is 2, not 3 + testAssert(plan[0].measureReps >= plan[1].measureReps); + testAssert(plan[1].measureReps >= plan[2].measureReps); + } + + std::cout << " planShapeRotation OK" << std::endl; + } + + { + // buildConv3x3HistogramsFromConvs — pure-function test on the conv + // filter+histogram. Constructs ConvLayerDesc instances directly + // (ConvLayerDesc is default-constructible but has a deleted copy ctor; + // see desc.h), so we build the descriptors in a deque (stable addresses, + // no copies on growth) and pass pointers to the helper. Does not touch + // ModelDesc. + + auto initConv = [](ConvLayerDesc& c, int kY, int kX, int inC, int outC) { + c.convYSize = kY; + c.convXSize = kX; + c.inChannels = inC; + c.outChannels = outC; + }; + + // Four layers: only the two 3x3 layers should contribute. + std::deque storage; + std::vector convs; + storage.emplace_back(); initConv(storage.back(), 1, 1, 10, 10); convs.push_back(&storage.back()); // 1x1 — filtered + storage.emplace_back(); initConv(storage.back(), 3, 3, 20, 30); convs.push_back(&storage.back()); // input_c[20]++, output_c[30]++ + storage.emplace_back(); initConv(storage.back(), 3, 3, 30, 30); convs.push_back(&storage.back()); // input_c[30]++, output_c[30]++ + storage.emplace_back(); initConv(storage.back(), 5, 5, 40, 40); convs.push_back(&storage.back()); // 5x5 — filtered + + auto [inHist, outHist] = + MLXWinogradTuner::buildConv3x3HistogramsFromConvsForTesting(convs); + + // Convert to maps for order-independent comparison. + std::map inMap(inHist.begin(), inHist.end()); + std::map outMap(outHist.begin(), outHist.end()); + + testAssert(inMap.size() == 2); + testAssert(inMap[20] == 1); + testAssert(inMap[30] == 1); + testAssert(inMap.count(10) == 0); // 1x1 didn't leak through + testAssert(inMap.count(40) == 0); // 5x5 didn't leak through + + testAssert(outMap.size() == 1); + testAssert(outMap[30] == 2); + testAssert(outMap.count(10) == 0); + testAssert(outMap.count(40) == 0); + + // Asymmetric 3x3 (e.g. 3x1) must also be filtered — the kernel is + // strictly square-3. + std::deque asymStorage; + std::vector asym; + asymStorage.emplace_back(); initConv(asymStorage.back(), 3, 1, 16, 16); asym.push_back(&asymStorage.back()); + asymStorage.emplace_back(); initConv(asymStorage.back(), 1, 3, 16, 16); asym.push_back(&asymStorage.back()); + asymStorage.emplace_back(); initConv(asymStorage.back(), 3, 3, 16, 16); asym.push_back(&asymStorage.back()); + auto [inA, outA] = + MLXWinogradTuner::buildConv3x3HistogramsFromConvsForTesting(asym); + testAssert(inA.size() == 1 && inA[0].first == 16 && inA[0].second == 1); + testAssert(outA.size() == 1 && outA[0].first == 16 && outA[0].second == 1); + + // Empty input → empty histograms (no assert; this is just the pure + // core. The mlxbackend.cpp call site asserts non-empty after a real + // model walk; mlxbackend.cpp pre-computes the histogram at model + // load and stores it on ModelInfoForTuning so the tuner does not + // re-walk the descriptor). + std::vector empty; + auto [inE, outE] = + MLXWinogradTuner::buildConv3x3HistogramsFromConvsForTesting(empty); + testAssert(inE.empty()); + testAssert(outE.empty()); + + std::cout << " buildConv3x3HistogramsFromConvs OK" << std::endl; + } + + // ---- v3 round-trip: tg0/tg1/wpt/vw/gridOrder (input), tg0/tg1/wpt (output) ---- + { + // v3 roundtrip: write -> load -> compare all 8 fields. Two + // cases for input gridOrder: Cfast and Tfast. (Tfast forces vw=1 per + // isValid invariant.) + using namespace MLXWinograd; + for(auto inGo : {GridOrder::Cfast, GridOrder::Tfast}) { + MLXWinogradTuneParams p; + p.inputTransform.tg0 = 32; + p.inputTransform.tg1 = 1; + p.inputTransform.wpt = 2; + p.inputTransform.vw = (inGo == GridOrder::Cfast) ? 2 : 1; + p.inputTransform.gridOrder = inGo; + p.outputUntransform.tg0 = 32; + p.outputUntransform.tg1 = 8; + p.outputUntransform.wpt = 1; + testAssert(p.isValid()); + + std::string tmpFile = "/tmp/katago_mlx_winotuner_v3_roundtrip_" + std::to_string((int)inGo) + ".txt"; + MLXWinogradTuneParams::save(tmpFile, p); + MLXWinogradTuneParams q = MLXWinogradTuneParams::load(tmpFile); + testAssert(q.inputTransform.tg0 == p.inputTransform.tg0); + testAssert(q.inputTransform.tg1 == p.inputTransform.tg1); + testAssert(q.inputTransform.wpt == p.inputTransform.wpt); + testAssert(q.inputTransform.vw == p.inputTransform.vw); + testAssert(q.inputTransform.gridOrder == p.inputTransform.gridOrder); + testAssert(q.outputUntransform.tg0 == p.outputUntransform.tg0); + testAssert(q.outputUntransform.tg1 == p.outputUntransform.tg1); + testAssert(q.outputUntransform.wpt == p.outputUntransform.wpt); + testAssert(q.isValid()); + std::remove(tmpFile.c_str()); + } + cout << " v3 roundtrip (Cfast + Tfast) OK" << endl; + } + + // dtype-aware cache filenames must coexist in the same directory + // without collision. Verify defaultFileName gains a _fp16/_fp32 suffix. + { + std::string nameF32 = MLXWinogradTuner::defaultFileName( + "AppleSilicon", 19, 19, 384, 13, /*useFP16=*/false); + std::string nameF16 = MLXWinogradTuner::defaultFileName( + "AppleSilicon", 19, 19, 384, 13, /*useFP16=*/true); + testAssert(nameF32 != nameF16); + testAssert(nameF32.find("_fp32") != std::string::npos); + testAssert(nameF16.find("_fp16") != std::string::npos); + testAssert(nameF32.size() >= 4 && nameF32.substr(nameF32.size()-4) == ".txt"); + testAssert(nameF16.size() >= 4 && nameF16.substr(nameF16.size()-4) == ".txt"); + cout << " defaultFileName dtype suffix OK: " + << nameF32 << " vs " << nameF16 << endl; + } + + // Fast (coarse) and full (wide) tunes must not collide on disk. The fast tune + // keeps the legacy name (no mode suffix) for backward compat; full gains + // "_full". Verify the two are distinct and the fast default is unchanged. + { + std::string fast = MLXWinogradTuner::defaultFileName( + "AppleSilicon", 19, 19, 384, 13, /*useFP16=*/true, /*full=*/false); + std::string full = MLXWinogradTuner::defaultFileName( + "AppleSilicon", 19, 19, 384, 13, /*useFP16=*/true, /*full=*/true); + std::string legacy = MLXWinogradTuner::defaultFileName( + "AppleSilicon", 19, 19, 384, 13, /*useFP16=*/true); // default full=false + testAssert(fast != full); + testAssert(fast == legacy); // fast == legacy name + testAssert(full.find("_full") != std::string::npos); + testAssert(fast.find("_full") == std::string::npos); + testAssert(full.size() >= 4 && full.substr(full.size()-4) == ".txt"); + cout << " defaultFileName fast/full suffix OK: " << fast << " vs " << full << endl; + } + + // detectGpuName() must yield a stable, non-empty, chip-specific cache key so + // that different Apple chips don't share one Winograd cache file, and must be + // filesystem-safe once threaded through defaultFileName (no spaces). + { + std::string gpu = MLXWinogradTuner::detectGpuName(); + testAssert(!gpu.empty()); + // Deterministic: the backend load path and the `tuner` command must derive + // the identical key on the same machine. + testAssert(MLXWinogradTuner::detectGpuName() == gpu); + std::string name = MLXWinogradTuner::defaultFileName(gpu, 19, 19, 384, 13, /*useFP16=*/true); + testAssert(name.find(' ') == std::string::npos); + testAssert(name.find("_gpu") != std::string::npos); + testAssert(name.size() >= 4 && name.substr(name.size()-4) == ".txt"); + cout << " detectGpuName OK: \"" << gpu << "\" -> " << name << endl; + } + + // ---- Corrupt-version rejection ---- + { + std::string tmp = "/tmp/katago_mlx_winotuner_badversion.txt"; + { + std::ofstream f(tmp); + f << "VERSION=999\n#inputTransform\ntg0=32 tg1=1\n#outputUntransform\ntg0=32 tg1=1\n"; + } + bool threw = false; + try { (void)MLXWinogradTuneParams::load(tmp); } + catch(const IOError&) { threw = true; } + testAssert(threw); + } + + // ---- v3 isValid invariants ---- + { + // v3 isValid invariants. + using namespace MLXWinograd; + auto basePass = [&]() { + MLXWinogradTuneParams p; + p.inputTransform = {32, 1, 1, 2, GridOrder::Cfast}; + p.outputUntransform = {32, 2, 1}; + return p; + }; + + // Baseline passes. + testAssert(basePass().isValid()); + + // tg0 <= 0 fails. + { auto p = basePass(); p.inputTransform.tg0 = 0; testAssert(!p.isValid()); } + { auto p = basePass(); p.outputUntransform.tg0 = -1; testAssert(!p.isValid()); } + + // tg0 * tg1 > 1024 fails. + { auto p = basePass(); p.inputTransform.tg0 = 64; p.inputTransform.tg1 = 32; + testAssert(!p.isValid()); } + + // wpt < 1 fails. + { auto p = basePass(); p.inputTransform.wpt = 0; testAssert(!p.isValid()); } + { auto p = basePass(); p.outputUntransform.wpt = 0; testAssert(!p.isValid()); } + + // vw < 1 fails on input. + { auto p = basePass(); p.inputTransform.vw = 0; testAssert(!p.isValid()); } + + // Tfast on input forces vw=1. + { auto p = basePass(); + p.inputTransform.gridOrder = GridOrder::Tfast; + p.inputTransform.vw = 2; + testAssert(!p.isValid()); } + { auto p = basePass(); + p.inputTransform.gridOrder = GridOrder::Tfast; + p.inputTransform.vw = 1; + testAssert(p.isValid()); } + + cout << " v3 isValid invariants OK" << endl; + } + + // Candidate enumeration with validity filtering. + { + using namespace MLXWinograd; + // Cfast, C=64 (divisible by all vw): full Cartesian product over all axes + // minus tg0*tg1>1024. + auto cands = MLXWinogradTuner::buildInputCandidatesForTesting( + /*full*/true, /*C*/64, /*Ntiles*/200, GridOrder::Cfast); + + // Sanity: returns hundreds of valid configs. + testAssert(cands.size() > 100); + testAssert(cands.size() < 5000); // bounded by validity filter + + // All candidates satisfy tg0*tg1 <= 1024. + for(const auto& c : cands) + testAssert(c.tg0 * c.tg1 <= 1024); + + // C=66 with vw>1: should filter out vw=2 (66%2=0 — VW=2 allowed) + // and vw=4 (66%4=2 != 0 — VW=4 should NOT appear in candidates). + auto cands_C66 = MLXWinogradTuner::buildInputCandidatesForTesting( + true, /*C*/66, /*Ntiles*/200, GridOrder::Cfast); + for(const auto& c : cands_C66) { + if(c.vw == 4) + testAssert(false); // vw=4 candidate should have been filtered out for C=66 + } + + // Tfast: vw must be 1 (kernel static_assert). All Tfast candidates have vw=1. + auto cands_Tfast = MLXWinogradTuner::buildInputCandidatesForTesting( + true, 64, 200, GridOrder::Tfast); + for(const auto& c : cands_Tfast) { + testAssert(c.vw == 1); + testAssert(c.gridOrder == GridOrder::Tfast); + } + + // Output side: same shape of assertions. (gridOrder is not a parameter + // of buildOutputCandidatesForTesting — output is Cfast-only.) + auto out_cands = MLXWinogradTuner::buildOutputCandidatesForTesting( + true, /*outC*/64, /*Ntiles*/200); + testAssert(out_cands.size() > 100); + for(const auto& c : out_cands) + testAssert(c.tg0 * c.tg1 <= 1024); + + std::cout << " MLX Winograd candidate enumeration validity passed (" + << cands.size() << " input / " << out_cands.size() << " output candidates C=64)" + << std::endl; + } + + // ---- Measurement primitives return finite positive times ---- + // We can't call the static helpers from the test, so we use the public + // surface: loadOrAutoTune with reTune=true runs the search and we verify + // that the public schema struct works with valid configs. The measurement + // primitive itself is exercised by the search-works test below. + + { + // Gated flat-sweep convergence test. + // Runs the production flat sweep on a small synthetic problem and asserts + // that the winner is isValid and that its timing is no worse than the + // baked default (tg0=32, tg1=1, wpt=1, vw=1, Cfast). + const char* gate = std::getenv("KATAGO_MLX_WINOTUNER_RUN_SWEEP_TEST"); + if(gate != nullptr && std::string(gate) == "1") { + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = 64; + mi.modelVersion = 11; + // Synthetic single-shape histogram for the toy C=64 test model. + mi.conv3x3InputHistogram = {{64, 1}}; + mi.conv3x3OutputHistogram = {{64, 1}}; + + // loadOrAutoTune rewrites an empty tunerFile to a default cache path, + // so use an explicit temp path and remove it after to avoid touching + // the user's cache directory. + std::string tmpTunerFile = "/tmp/katago_mlx_winotuner_sweep_cache.txt"; + std::remove(tmpTunerFile.c_str()); + + MLXWinogradTuneParams tuned = MLXWinogradTuner::loadOrAutoTune( + /*tunerFile=*/tmpTunerFile, + /*homeDataDirOverride=*/"", + /*gpuName=*/"AppleSilicon", + /*nnXLen=*/19, /*nnYLen=*/19, /*batchSize=*/1, + mi, + /*logger=*/nullptr, + /*full=*/false, + /*reTune=*/true, + /*useFP16=*/true); + testAssert(tuned.isValid()); + + // Score the baked default and the tuned winner via scoreInputTransform. + // tuned.time <= baked.time (within noise). + MLXWinograd::InputTransform baked{}; + baked.tg0 = 32; baked.tg1 = 1; baked.wpt = 1; baked.vw = 1; + baked.gridOrder = MLXWinograd::GridOrder::Cfast; + auto bestOf5 = [&](const MLXWinograd::InputTransform& cfg) -> double { + double best = std::numeric_limits::infinity(); + for(int rep = 0; rep < 5; rep++) { + double t = MLXWinogradTuner::scoreInputTransformForTesting( + cfg, 1, 19, 19, mi, true); + if(t < best) best = t; + } + return best; + }; + double bakedMs = bestOf5(baked); + double tunedMs = bestOf5(tuned.inputTransform); + // The sweep seeds the default as its floor and only replaces it with a + // strictly-faster candidate, so the winner is <= default in the sweep's + // OWN measurement by construction. Here we re-measure both (bestOf5) as + // an independent sanity check that the winner isn't grossly worse. On the + // coarse auto grid the winner on this toy single-shape problem is often + // within noise of the default, so the two bestOf5 values are two noisy + // sub-millisecond samples of near-tied configs (each carries the ~10-15% + // dispatch/sync-overhead noise floor — the same noise behind the tuner's + // plateau). A 1.10 bound flips intermittently on both this and the + // pre-change binary (observed ratios to ~1.06 even on the wide grid); a + // 1.30 bound absorbs that while still catching a winner that is genuinely + // much slower than default (which would signal a real sweep/scoring bug). + testAssert(tunedMs <= bakedMs * 1.30); + std::cout << " flat-sweep convergence (gated) OK" + << " bakedMs=" << bakedMs + << " tunedMs=" << tunedMs << std::endl; + + std::remove(tmpTunerFile.c_str()); + } + } + + { + // Baseline anchor — Test 1: log-format gated check (input stage). + // Asserts that flatSweepInput's log line carries the new baseline_ms and + // delta_pct fields with the documented format. Gated because the synthetic + // sweep takes a few seconds; opt in with the env var below. + const char* gate = std::getenv("KATAGO_MLX_WINOTUNER_RUN_LOG_FORMAT_TEST"); + if(gate != nullptr && std::string(gate) == "1") { + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = 64; + mi.modelVersion = 11; + // Synthetic single-shape histogram for the toy C=64 test model. + mi.conv3x3InputHistogram = {{64, 1}}; + mi.conv3x3OutputHistogram = {{64, 1}}; + + std::string tmpTunerFile = "/tmp/baseline_anchor_log_format.txt"; + std::remove(tmpTunerFile.c_str()); + + std::ostringstream captured; + Logger logger(nullptr, /*logToStdoutDefault=*/false, + /*logToStderrDefault=*/false, /*logTimeDefault=*/false, + /*logConfigContents=*/false); + logger.addOStream(captured); + + (void)MLXWinogradTuner::loadOrAutoTune( + /*tunerFile=*/tmpTunerFile, + /*homeDataDirOverride=*/"", + /*gpuName=*/"AppleSilicon", + /*nnXLen=*/19, /*nnYLen=*/19, /*batchSize=*/1, + mi, + /*logger=*/&logger, + /*full=*/false, + /*reTune=*/true, + /*useFP16=*/true); + + const std::string log = captured.str(); + // Logger::writeLocked prefixes each line with ": " when logTime=false, so + // `log` reads ": MLX tuner ...". std::regex_search is anchor-free so the + // ": " prefix is transparent; only the substring match matters here. + // The regex matches the non-degenerate path only (best != nullopt). The + // best=none / delta_pct=nan branch is unreachable for the synthetic 19x19 + // C=64 problem this test runs against (hundreds of valid candidates). + // Updated for shape diagnostic: regex now requires the per-shape + // median fields appended by flatSweepInput. + std::regex inputRe( + R"(MLX tuner flatSweepInput: considered=[0-9]+ best=tg0=[0-9]+ tg1=[0-9]+ wpt=[0-9]+ vw=[0-9]+ gridOrder=[01] time_ms=[0-9]+\.[0-9]+ baseline_ms=[0-9]+\.[0-9]+ delta_pct=[-+][0-9]+\.[0-9]+ shape_ms=c[0-9]+:[0-9]+\.[0-9]+(?:,c[0-9]+:[0-9]+\.[0-9]+)*)"); + testAssert(std::regex_search(log, inputRe)); + std::cout << " flatSweepInput log-format (gated) OK" << std::endl; + + std::regex outputRe( + R"(MLX tuner flatSweepOutput: considered=[0-9]+ best=tg0=[0-9]+ tg1=[0-9]+ wpt=[0-9]+ time_ms=[0-9]+\.[0-9]+ baseline_ms=[0-9]+\.[0-9]+ delta_pct=[-+][0-9]+\.[0-9]+ shape_ms=c[0-9]+:[0-9]+\.[0-9]+(?:,c[0-9]+:[0-9]+\.[0-9]+)*)"); + testAssert(std::regex_search(log, outputRe)); + std::cout << " flatSweepOutput log-format (gated) OK" << std::endl; + + std::remove(tmpTunerFile.c_str()); + } + } + + { + // Baseline anchor — Test 2: baseline-consistency gated check. + // Asserts that the baseline_ms value printed by flatSweepInput is in the + // same ballpark as an independent re-score of the default-constructed + // InputTransform — a gross-error sanity check (catches a ~2x units/logic + // bug in the logged baseline), NOT a precision check. + // + // parsedBaseline is a SINGLE 20-rep weighted mean (one call into + // scoreInputTransform, logged by the sweep). minOf3 is the min of three + // such weighted means — biased low by min-selection (~5-10%), while + // parsedBaseline stays a high-variance single sample. Both carry the ~10% + // per-sample noise floor of these sub-millisecond kernels (steady_clock + // around one dispatch includes fixed dispatch/sync overhead — the same + // noise that makes the tuner's own winner a draw from a plateau). The gap + // is therefore positively skewed: across runs the same-config relErr + // clusters <0.15 but tails to ~0.3 and occasionally past 0.4. A single + // sample cannot support a tight bound, so we use the same 0.50 budget as + // the sibling per-shape check below; it still flags a gross measurement + // bug. Tight precision belongs in same-config bit-for-bit tests, not here. + // + // Reuses the KATAGO_MLX_WINOTUNER_RUN_SWEEP_TEST gate so users who + // opt into the sweep-convergence cost also get this check. Note + // this runs an INDEPENDENT loadOrAutoTune sweep — total cost when + // the gate is set is roughly 2x the cost of a single sweep. + // + // Coverage scope: input stage only. flatSweepOutput's baseline_ms + // is format-checked by Test 1 but not consistency-checked here. + // The output kernel uses a different scoring function and default + // struct (OutputUntransform{}); a symmetric check is deferred. + const char* gate = std::getenv("KATAGO_MLX_WINOTUNER_RUN_SWEEP_TEST"); + if(gate != nullptr && std::string(gate) == "1") { + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = 64; + mi.modelVersion = 11; + // Synthetic single-shape histogram for the toy C=64 test model. + mi.conv3x3InputHistogram = {{64, 1}}; + mi.conv3x3OutputHistogram = {{64, 1}}; + + std::string tmpTunerFile = "/tmp/baseline_anchor_consistency.txt"; + std::remove(tmpTunerFile.c_str()); + + std::ostringstream captured; + Logger logger(nullptr, /*logToStdoutDefault=*/false, + /*logToStderrDefault=*/false, /*logTimeDefault=*/false, + /*logConfigContents=*/false); + logger.addOStream(captured); + + (void)MLXWinogradTuner::loadOrAutoTune( + /*tunerFile=*/tmpTunerFile, + /*homeDataDirOverride=*/"", + /*gpuName=*/"AppleSilicon", + /*nnXLen=*/19, /*nnYLen=*/19, /*batchSize=*/1, + mi, + /*logger=*/&logger, + /*full=*/false, + /*reTune=*/true, + /*useFP16=*/true); + + const std::string log = captured.str(); + std::smatch m; + std::regex baselineRe(R"(flatSweepInput:[^\n]*baseline_ms=([0-9]+\.[0-9]+))"); + testAssert(std::regex_search(log, m, baselineRe)); + const double parsedBaseline = std::stod(m[1].str()); + + double minOf3 = std::numeric_limits::infinity(); + for(int rep = 0; rep < 3; rep++) { + double t = MLXWinogradTuner::scoreInputTransformForTesting( + MLXWinograd::InputTransform{}, 1, 19, 19, mi, true); + if(t < minOf3) minOf3 = t; + } + + const double relErr = std::abs(parsedBaseline - minOf3) / minOf3; + testAssert(relErr < 0.50); + std::cout << " baseline-consistency (gated) OK" + << " parsed=" << parsedBaseline + << " minOf3=" << minOf3 + << " relErr=" << relErr << std::endl; + + std::remove(tmpTunerFile.c_str()); + } + } + + { + // Per-shape numeric consistency — Test 2 from the shape-diagnostic spec. + // Asserts the dominant-shape median printed by flatSweepInput + // (shape_ms=c:) is in the same ballpark as an independent + // reference measurement of the default InputTransform{} on that shape. + // + // IMPORTANT — cross-config comparison: parsedDominantMs is measured by + // flatSweepInput on the WINNER configuration (whatever the sweep + // selected). minOf3 is computed on the DEFAULT InputTransform{} via + // three independent scoreInputTransformPerShapeForTesting calls. These + // are not the same config, so the relative-error budget is necessarily + // loose. The budget covers: + // - winner-vs-default speed gap (sweep can find configs 10-40% + // faster than default on some shapes/hardware) + // - selection bias on the min-of-3 reference (~5-10% low vs single) + // - per-call noise floor (~10%) + // The 50% budget is intentionally conservative; this is a sanity-check + // that measurement is roughly working, not a tight precision check. + // Tighter precision checks belong in same-config stability tests. + // + // Coverage scope: input stage only. flatSweepOutput's per-shape fields + // are format-checked by the log-format test (gate + // KATAGO_MLX_WINOTUNER_RUN_LOG_FORMAT_TEST) but not consistency- + // checked here — symmetric output check is deferred. + // + // Gate is new (KATAGO_MLX_WINOTUNER_RUN_PER_SHAPE_TEST) and separate + // from the baseline-anchor gate above; this test runs an additional + // tuner sweep. + const char* gate = std::getenv("KATAGO_MLX_WINOTUNER_RUN_PER_SHAPE_TEST"); + if(gate != nullptr && std::string(gate) == "1") { + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = 64; + mi.modelVersion = 11; + // Synthetic single-shape histogram for the toy C=64 test model. + mi.conv3x3InputHistogram = {{64, 1}}; + mi.conv3x3OutputHistogram = {{64, 1}}; + + std::string tmpTunerFile = "/tmp/per_shape_consistency.txt"; + std::remove(tmpTunerFile.c_str()); + + std::ostringstream captured; + Logger logger(nullptr, /*logToStdoutDefault=*/false, + /*logToStderrDefault=*/false, /*logTimeDefault=*/false, + /*logConfigContents=*/false); + logger.addOStream(captured); + + (void)MLXWinogradTuner::loadOrAutoTune( + /*tunerFile=*/tmpTunerFile, + /*homeDataDirOverride=*/"", + /*gpuName=*/"AppleSilicon", + /*nnXLen=*/19, /*nnYLen=*/19, /*batchSize=*/1, + mi, + /*logger=*/&logger, + /*full=*/false, + /*reTune=*/true, + /*useFP16=*/true); + + const std::string log = captured.str(); + std::smatch m; + std::regex trunkRe(R"(flatSweepInput:[^\n]*shape_ms=c[0-9]+:([0-9]+\.[0-9]+))"); + testAssert(std::regex_search(log, m, trunkRe)); + const double parsedDominantMs = std::stod(m[1].str()); + + // Per-shape consistency: parse the dominant shape's median from + // the flatSweepInput log line (which used scoreInputTransformPerShape + // on the winner) and compare against scoreInputTransformPerShapeForTesting + // on the default InputTransform. Cross-config (winner vs default) + // so a wide relErr bound (<0.50) is appropriate. + std::vector> r1 = + MLXWinogradTuner::scoreInputTransformPerShapeForTesting( + MLXWinograd::InputTransform{}, 1, 19, 19, mi, true); + std::vector> r2 = + MLXWinogradTuner::scoreInputTransformPerShapeForTesting( + MLXWinograd::InputTransform{}, 1, 19, 19, mi, true); + std::vector> r3 = + MLXWinogradTuner::scoreInputTransformPerShapeForTesting( + MLXWinograd::InputTransform{}, 1, 19, 19, mi, true); + testAssert(!r1.empty() && !r2.empty() && !r3.empty()); + // Each result has the same shapes in the same order; take the + // dominant (index 0) per-shape median across the 3 runs. + double minOf3 = std::min({r1[0].second, r2[0].second, r3[0].second}); + + const double relErr = std::abs(parsedDominantMs - minOf3) / minOf3; + // 50% budget — see comment block above for rationale on the loose + // bound (cross-config comparison + selection bias + noise). + testAssert(relErr < 0.50); + std::cout << " per-shape dominant consistency (gated) OK" + << " parsed=" << parsedDominantMs + << " minOf3=" << minOf3 + << " relErr=" << relErr << std::endl; + + std::remove(tmpTunerFile.c_str()); + } + } + + { + // Per-shape scoring smoke test: verify that scoreInputTransformPerShape + // and scoreOutputUntransformPerShape return finite positive values for + // each planned shape with a default-constructed + // InputTransform/OutputUntransform on a tiny shape. Gated under the same + // env var as the other GPU-touching tests; ungated CI shouldn't pay for + // GPU work. + const char* gate = std::getenv("KATAGO_MLX_WINOTUNER_RUN_SWEEP_TEST"); + if(gate != nullptr && std::string(gate) == "1") { + MLXWinogradTuner::ModelInfoForTuning mi; + mi.trunkNumChannels = 64; + mi.modelVersion = 11; + // Synthetic single-shape histogram for the toy C=64 test model. + mi.conv3x3InputHistogram = {{64, 1}}; + mi.conv3x3OutputHistogram = {{64, 1}}; + + std::vector> in = + MLXWinogradTuner::scoreInputTransformPerShapeForTesting( + MLXWinograd::InputTransform{}, 1, 19, 19, mi, true); + testAssert(!in.empty()); + for(const auto& [c, v] : in) { + testAssert(c > 0); + testAssert(std::isfinite(v)); + testAssert(v > 0.0); + testAssert(v < 1000.0); // sanity: <1s per call on Apple Silicon + } + + std::vector> out = + MLXWinogradTuner::scoreOutputUntransformPerShapeForTesting( + MLXWinograd::OutputUntransform{}, 1, 19, 19, mi, true); + testAssert(!out.empty()); + for(const auto& [c, v] : out) { + testAssert(c > 0); + testAssert(std::isfinite(v)); + testAssert(v > 0.0); + testAssert(v < 1000.0); + } + std::cout << " per-shape scoring smoke (gated) OK" + << " in[0]=c" << in[0].first << ":" << in[0].second + << " out[0]=c" << out[0].first << ":" << out[0].second + << std::endl; + } + } + + cout << "MLX Winograd tuner tests passed" << endl; +} + +#endif // USE_MLX_BACKEND diff --git a/cpp/neuralnet/mlxwinograd.h b/cpp/neuralnet/mlxwinograd.h new file mode 100644 index 000000000..2be32a4f7 --- /dev/null +++ b/cpp/neuralnet/mlxwinograd.h @@ -0,0 +1,705 @@ +#ifndef NEURALNET_MLXWINOGRAD_H_ +#define NEURALNET_MLXWINOGRAD_H_ + +#ifdef USE_MLX_BACKEND + +#include + +namespace MLXWinograd { + +enum class GridOrder : int { Cfast = 0, Tfast = 1 }; + +// Per-stage launch-geometry configs. Input transform exposes +// (tg0, tg1, wpt, vw, gridOrder); output untransform exposes (tg0, tg1, wpt). +// The output kernel is monomorphic on VW=1, GRID_ORDER=Cfast, and the +// matmul layout is monomorphic on Std for both stages. +struct InputTransform { + int tg0 = 32; + int tg1 = 1; + int wpt = 1; // tiles per thread; {1, 2, 4, 8} + int vw = 1; // vector width; {1, 2, 4} + GridOrder gridOrder = GridOrder::Cfast; +}; +struct OutputUntransform { + int tg0 = 32; + int tg1 = 1; + int wpt = 1; +}; + +// F(2,3) 1D transform matrices. +inline constexpr float BT[4][4] = { + {1.f, 0.f,-1.f, 0.f}, + {0.f, 1.f, 1.f, 0.f}, + {0.f,-1.f, 1.f, 0.f}, + {0.f, 1.f, 0.f,-1.f} +}; +inline constexpr float G[4][3] = { + {1.f, 0.f, 0.f}, + {0.5f,0.5f,0.5f}, + {0.5f,-0.5f,0.5f}, + {0.f, 0.f, 1.f} +}; +inline constexpr float AT[2][4] = { + {1.f, 1.f, 1.f, 0.f}, + {0.f, 1.f,-1.f,-1.f} +}; + +// Transform one 3x3 filter g -> 4x4 U = G g G^T. +inline void transformWeight(const float g[3][3], float U[4][4]) { + float Gg[4][3]; + for(int i=0;i<4;i++) for(int j=0;j<3;j++) { + float s=0.f; for(int k=0;k<3;k++) s += G[i][k]*g[k][j]; Gg[i][j]=s; + } + for(int i=0;i<4;i++) for(int j=0;j<4;j++) { + float s=0.f; for(int k=0;k<3;k++) s += Gg[i][k]*G[j][k]; U[i][j]=s; + } +} + +// Transform one 4x4 input tile d -> 4x4 V = B^T d B. +inline void transformInput(const float d[4][4], float V[4][4]) { + float Bd[4][4]; + for(int i=0;i<4;i++) for(int j=0;j<4;j++) { + float s=0.f; for(int k=0;k<4;k++) s += BT[i][k]*d[k][j]; Bd[i][j]=s; + } + for(int i=0;i<4;i++) for(int j=0;j<4;j++) { + float s=0.f; for(int k=0;k<4;k++) s += Bd[i][k]*BT[j][k]; V[i][j]=s; + } +} + +// Inverse transform 4x4 M -> 2x2 Y = A^T M A. +inline void transformOutput(const float M[4][4], float Y[2][2]) { + float AM[2][4]; + for(int i=0;i<2;i++) for(int j=0;j<4;j++) { + float s=0.f; for(int k=0;k<4;k++) s += AT[i][k]*M[k][j]; AM[i][j]=s; + } + for(int i=0;i<2;i++) for(int j=0;j<2;j++) { + float s=0.f; for(int k=0;k<4;k++) s += AM[i][k]*AT[j][k]; Y[i][j]=s; + } +} + +// Full CPU reference NHWC Winograd F(2,3) "same" conv, stride 1. +// in: [N][H][W][Cin], weights OIHW flattened [Cout][Cin][3][3], out: [N][H][W][Cout]. +inline std::vector cpuConv2d3x3( + const std::vector& in, int N, int H, int W, int Cin, + const std::vector& wOIHW, int Cout +) { + std::vector out((size_t)N*H*W*Cout, 0.f); + // Precompute U per (oc,ic). + std::vector U((size_t)Cout*Cin*16); + for(int oc=0;oc=0&&iy=0&&ix U array. +// Layout: [16, Cin, Cout] — Cout fast (matmul sees [16,Ntiles,Cin] x [16,Cin,Cout] -> [16,Ntiles,Cout]). +// Output layout: Std only. +inline mx::array makeWinogradWeights(const std::vector& wOIHW, + int Cout, int Cin, + bool useFP16 = false) { + std::vector U((size_t)16 * Cin * Cout, 0.0f); + for(int oc = 0; oc < Cout; oc++) { + for(int ic = 0; ic < Cin; ic++) { + float g[3][3]; + for(int a = 0; a < 3; a++) + for(int b = 0; b < 3; b++) + g[a][b] = wOIHW[(((size_t)oc * Cin + ic) * 3 + a) * 3 + b]; + float Um[4][4]; transformWeight(g, Um); + for(int a = 0; a < 4; a++) { + for(int b = 0; b < 4; b++) { + // [16, Cin, Cout] — Cout fast + size_t idx = ((size_t)(a * 4 + b) * Cin + ic) * Cout + oc; + U[idx] = Um[a][b]; + } + } + } + } + mx::Shape shape = {16, Cin, Cout}; + mx::array arr(U.data(), shape, mx::float32); + if(useFP16) { + mx::array casted = mx::astype(arr, mx::float16); + // Realize on the constructor thread so the resulting array is a + // materialized constant. Without this, a model cached and shared + // across threads carries an unevaluated AsType primitive that is + // stamped with the constructor thread's stream — calling thread's + // mx::eval then fails with "There is no Stream(gpu, N) in current + // thread." for the constructor thread's stream index. + mx::eval(casted); + return casted; + } + return arr; +} + +// F(2,3) input transform kernel: NHWC T input -> [16, Ntiles, C] T output. +// The matmul layout is monomorphic on Std ([16, Ntiles, C]). +// Template args (JIT-substituted via MLX template_args): +// T — float or half (precision) +// WPT — tiles per thread +// VW — vector width for packed loads +// GRID_ORDER — 0=Cfast (C is fast axis), 1=Tfast (Ntiles fast) +// Grid: +// Cfast: (ceil(C/VW), ceil(Ntiles/WPT), 1) +// Tfast: (Ntiles, ceil(C/WPT), 1) +inline constexpr const char* kWinoInputSource = R"METAL( + static_assert(WPT >= 1 && VW >= 1, "WPT and VW must be positive"); + // Tfast (GRID_ORDER=1) does not support VW>1. + static_assert(GRID_ORDER == 0 || VW == 1, "Tfast (GRID_ORDER=1) requires VW=1"); + + int N_k = inp_shape[0]; + int H_k = inp_shape[1]; + int W_k = inp_shape[2]; + int C_k = inp_shape[3]; + int tilesY_k = (H_k + 1) / 2; + int tilesX_k = (W_k + 1) / 2; + int Ntiles_k = N_k * tilesY_k * tilesX_k; + + if (GRID_ORDER == 0) { + // Cfast: grid x = ceil(C/VW), grid y = ceil(Ntiles/WPT). + // Each thread owns VW channels (inner vc loop) and WPT tiles (outer w loop). + uint c_group = thread_position_in_grid.x; + uint t_group = thread_position_in_grid.y; + + for (int w = 0; w < WPT; w++) { + int tileIdx = (int)t_group * WPT + w; + if (tileIdx >= Ntiles_k) break; + + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + + for (int vc = 0; vc < VW; vc++) { + int c = (int)c_group * VW + vc; + if (c >= C_k) break; + T d[4][4]; + for (int i = 0; i < 4; i++) { + int iy = 2 * ty - 1 + i; + for (int j = 0; j < 4; j++) { + int ix = 2 * tx - 1 + j; + if (iy < 0 || iy >= H_k || ix < 0 || ix >= W_k) { + d[i][j] = (T)0.0f; + } else { + d[i][j] = inp[((n * H_k + iy) * W_k + ix) * C_k + c]; + } + } + } + // Transform accumulates in fp32; only the stored V rounds to T (fp16-safe). + float tmp[4][4]; + for (int j = 0; j < 4; j++) { + float v0 = (float)d[0][j], v1 = (float)d[1][j], v2 = (float)d[2][j], v3 = (float)d[3][j]; + tmp[0][j] = v0 - v2; + tmp[1][j] = v1 + v2; + tmp[2][j] = v2 - v1; + tmp[3][j] = v1 - v3; + } + for (int r = 0; r < 4; r++) { + float u0 = tmp[r][0], u1 = tmp[r][1], u2 = tmp[r][2], u3 = tmp[r][3]; + float V0 = u0 - u2; + float V1 = u1 + u2; + float V2 = u2 - u1; + float V3 = u1 - u3; + // outp [16, Ntiles, C] — C is the fast axis. + int base = ((r * 4 + 0) * Ntiles_k + tileIdx) * C_k + c; + outp[base + 0 * Ntiles_k * C_k] = (T)V0; + outp[base + 1 * Ntiles_k * C_k] = (T)V1; + outp[base + 2 * Ntiles_k * C_k] = (T)V2; + outp[base + 3 * Ntiles_k * C_k] = (T)V3; + } + } + } + } else { + // Tfast: grid x = Ntiles, grid y = ceil(C/WPT). VW must be 1 (enforced + // by the static_assert above). + uint t_group_ = thread_position_in_grid.x; + uint c_group_ = thread_position_in_grid.y; + int tileIdx = (int)t_group_; + if (tileIdx >= Ntiles_k) return; + + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + + for (int w = 0; w < WPT; w++) { + int c = (int)c_group_ * WPT + w; + if (c >= C_k) break; + T d[4][4]; + for (int i = 0; i < 4; i++) { + int iy = 2 * ty - 1 + i; + for (int j = 0; j < 4; j++) { + int ix = 2 * tx - 1 + j; + if (iy < 0 || iy >= H_k || ix < 0 || ix >= W_k) { + d[i][j] = (T)0.0f; + } else { + d[i][j] = inp[((n * H_k + iy) * W_k + ix) * C_k + c]; + } + } + } + // Transform accumulates in fp32; only the stored V rounds to T (fp16-safe). + float tmp[4][4]; + for (int j = 0; j < 4; j++) { + float v0 = (float)d[0][j], v1 = (float)d[1][j], v2 = (float)d[2][j], v3 = (float)d[3][j]; + tmp[0][j] = v0 - v2; + tmp[1][j] = v1 + v2; + tmp[2][j] = v2 - v1; + tmp[3][j] = v1 - v3; + } + for (int r = 0; r < 4; r++) { + float u0 = tmp[r][0], u1 = tmp[r][1], u2 = tmp[r][2], u3 = tmp[r][3]; + float V0 = u0 - u2; + float V1 = u1 + u2; + float V2 = u2 - u1; + float V3 = u1 - u3; + // outp [16, Ntiles, C] — C is the fast axis. + int base = ((r * 4 + 0) * Ntiles_k + tileIdx) * C_k + c; + outp[base + 0 * Ntiles_k * C_k] = (T)V0; + outp[base + 1 * Ntiles_k * C_k] = (T)V1; + outp[base + 2 * Ntiles_k * C_k] = (T)V2; + outp[base + 3 * Ntiles_k * C_k] = (T)V3; + } + } + } +)METAL"; + +// F(2,3) output untransform kernel: [16, Ntiles, outC] T input -> NHWC T output. +// Template args (JIT-substituted via MLX template_args): +// T — float or half (precision) +// WPT — tiles per thread +// Grid: (Cout, ceil(Ntiles/WPT), 1). +// nhwc input array carries the [N,H,W,outC] dims because metal_kernel only +// exposes *_shape for inputs, not outputs. +// The output kernel is monomorphic on VW=1, GRID_ORDER=Cfast, and matmul +// layout=Std. (GRID_ORDER=Cfast was chosen from an empirical sensitivity +// sweep showing <1% delta vs Tfast; the other two are structural.) +inline constexpr const char* kWinoOutputSource = R"METAL( + static_assert(WPT >= 1, "WPT must be positive"); + + // m shape [16, Ntiles, outC] — Ntiles=m_shape[1], outC=m_shape[2]. + int Ntiles_k = m_shape[1]; + int outC_k = m_shape[2]; + int H_k = nhwc[1]; + int W_k = nhwc[2]; + int tilesY_k = (H_k + 1) / 2; + int tilesX_k = (W_k + 1) / 2; + + // Cfast: grid x = Cout, grid y = ceil(Ntiles/WPT). + uint oc_group = thread_position_in_grid.x; + uint t_group = thread_position_in_grid.y; + + for (int w = 0; w < WPT; w++) { + int tileIdx = (int)t_group * WPT + w; + if (tileIdx >= Ntiles_k) break; + + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + + { + int oc = (int)oc_group; + if (oc >= outC_k) break; + + T mm[4][4]; + for (int r = 0; r < 4; r++) { + for (int c2 = 0; c2 < 4; c2++) { + int p = r * 4 + c2; + // m shape [16, Ntiles, outC]. + mm[r][c2] = m[(p * Ntiles_k + tileIdx) * outC_k + oc]; + } + } + // Untransform accumulates in fp32; only the stored Y rounds to T (fp16-safe). + float tmp[2][4]; + for (int c2 = 0; c2 < 4; c2++) { + float v0 = (float)mm[0][c2], v1 = (float)mm[1][c2], v2 = (float)mm[2][c2], v3 = (float)mm[3][c2]; + tmp[0][c2] = v0 + v1 + v2; + tmp[1][c2] = v1 - v2 - v3; + } + for (int a = 0; a < 2; a++) { + float u0 = tmp[a][0], u1 = tmp[a][1], u2 = tmp[a][2], u3 = tmp[a][3]; + float Y0 = u0 + u1 + u2; + float Y1 = u1 - u2 - u3; + int oy0 = 2 * ty + a; + if (oy0 < H_k) { + int ox0 = 2 * tx + 0; + if (ox0 < W_k) + outp[((n * H_k + oy0) * W_k + ox0) * outC_k + oc] = (T)Y0; + int ox1 = 2 * tx + 1; + if (ox1 < W_k) + outp[((n * H_k + oy0) * W_k + ox1) * outC_k + oc] = (T)Y1; + } + } + } + } +)METAL"; + +// Output untransform + fused BN/activation epilogue. Extra inputs: scale,bias +// (fp32 [outC]). Template arg ACT: 0 identity, 1 mish, 2 relu. The epilogue +// consumes the rounded-to-T conv value (float)(T)Y to match the unfused path +// (which stores the conv as T, then a separate BN kernel reads it). +inline constexpr const char* kWinoOutputSourceBNAct = R"METAL( + static_assert(WPT >= 1, "WPT must be positive"); + int Ntiles_k = m_shape[1]; + int outC_k = m_shape[2]; + int H_k = nhwc[1]; + int W_k = nhwc[2]; + int tilesY_k = (H_k + 1) / 2; + int tilesX_k = (W_k + 1) / 2; + uint oc_group = thread_position_in_grid.x; + uint t_group = thread_position_in_grid.y; + for (int w = 0; w < WPT; w++) { + int tileIdx = (int)t_group * WPT + w; + if (tileIdx >= Ntiles_k) break; + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + { + int oc = (int)oc_group; + if (oc >= outC_k) break; + T mm[4][4]; + for (int r = 0; r < 4; r++) + for (int c2 = 0; c2 < 4; c2++) + mm[r][c2] = m[((r*4+c2) * Ntiles_k + tileIdx) * outC_k + oc]; + float tmp[2][4]; + for (int c2 = 0; c2 < 4; c2++) { + float v0=(float)mm[0][c2], v1=(float)mm[1][c2], v2=(float)mm[2][c2], v3=(float)mm[3][c2]; + tmp[0][c2] = v0 + v1 + v2; + tmp[1][c2] = v1 - v2 - v3; + } + float sc = (float)scale[oc]; + float bi = (float)bias[oc]; + for (int a = 0; a < 2; a++) { + float u0=tmp[a][0], u1=tmp[a][1], u2=tmp[a][2], u3=tmp[a][3]; + float Y0 = u0 + u1 + u2; + float Y1 = u1 - u2 - u3; + // Round to T first (match unfused stored-then-read), then BN+act in fp32. + float x0 = (float)(T)Y0, x1 = (float)(T)Y1; + x0 = x0*sc + bi; x1 = x1*sc + bi; + if (ACT == 1) { // mish: x * tanh(softplus(x)), softplus = logaddexp(0,x) + float s0 = metal::max(0.0f,x0) + metal::precise::log(1.0f + metal::precise::exp(-metal::abs(x0))); + float s1 = metal::max(0.0f,x1) + metal::precise::log(1.0f + metal::precise::exp(-metal::abs(x1))); + x0 = x0 * metal::precise::tanh(s0); + x1 = x1 * metal::precise::tanh(s1); + } else if (ACT == 2) { // relu + x0 = metal::max(0.0f,x0); x1 = metal::max(0.0f,x1); + } + int oy0 = 2*ty + a; + if (oy0 < H_k) { + int ox0 = 2*tx + 0; + if (ox0 < W_k) outp[((n*H_k+oy0)*W_k+ox0)*outC_k + oc] = (T)x0; + int ox1 = 2*tx + 1; + if (ox1 < W_k) outp[((n*H_k+oy0)*W_k+ox1)*outC_k + oc] = (T)x1; + } + } + } + } +)METAL"; + +// Output untransform + fused broadcast-bias add then BN/activation. Identical to +// kWinoOutputSourceBNAct except a per-(n,oc) broadcast bias gbias ([N,outC], T) is +// added to the rounded-to-T conv value before BN+act (matches gpool's regularOut + +// bias). Extra inputs: gbias (T [N,outC]), scale,bias (fp32 [outC]). Template arg +// ACT: 0 identity, 1 mish, 2 relu. +inline constexpr const char* kWinoOutputSourceBiasBNAct = R"METAL( + static_assert(WPT >= 1, "WPT must be positive"); + int Ntiles_k = m_shape[1]; + int outC_k = m_shape[2]; + int H_k = nhwc[1]; + int W_k = nhwc[2]; + int tilesY_k = (H_k + 1) / 2; + int tilesX_k = (W_k + 1) / 2; + uint oc_group = thread_position_in_grid.x; + uint t_group = thread_position_in_grid.y; + for (int w = 0; w < WPT; w++) { + int tileIdx = (int)t_group * WPT + w; + if (tileIdx >= Ntiles_k) break; + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + { + int oc = (int)oc_group; + if (oc >= outC_k) break; + T mm[4][4]; + for (int r = 0; r < 4; r++) + for (int c2 = 0; c2 < 4; c2++) + mm[r][c2] = m[((r*4+c2) * Ntiles_k + tileIdx) * outC_k + oc]; + float tmp[2][4]; + for (int c2 = 0; c2 < 4; c2++) { + float v0=(float)mm[0][c2], v1=(float)mm[1][c2], v2=(float)mm[2][c2], v3=(float)mm[3][c2]; + tmp[0][c2] = v0 + v1 + v2; + tmp[1][c2] = v1 - v2 - v3; + } + float sc = (float)scale[oc]; + float bi = (float)bias[oc]; + for (int a = 0; a < 2; a++) { + float u0=tmp[a][0], u1=tmp[a][1], u2=tmp[a][2], u3=tmp[a][3]; + float Y0 = u0 + u1 + u2; + float Y1 = u1 - u2 - u3; + // Broadcast bias add (T+T, matches gpool's regularOut + bias), then round, then BN+act. + T gb = gbias[n * outC_k + oc]; + float x0 = (float)((T)Y0 + gb), x1 = (float)((T)Y1 + gb); + x0 = x0*sc + bi; x1 = x1*sc + bi; + if (ACT == 1) { // mish: x * tanh(softplus(x)), softplus = logaddexp(0,x) + float s0 = metal::max(0.0f,x0) + metal::precise::log(1.0f + metal::precise::exp(-metal::abs(x0))); + float s1 = metal::max(0.0f,x1) + metal::precise::log(1.0f + metal::precise::exp(-metal::abs(x1))); + x0 = x0 * metal::precise::tanh(s0); + x1 = x1 * metal::precise::tanh(s1); + } else if (ACT == 2) { // relu + x0 = metal::max(0.0f,x0); x1 = metal::max(0.0f,x1); + } + int oy0 = 2*ty + a; + if (oy0 < H_k) { + int ox0 = 2*tx + 0; + if (ox0 < W_k) outp[((n*H_k+oy0)*W_k+ox0)*outC_k + oc] = (T)x0; + int ox1 = 2*tx + 1; + if (ox1 < W_k) outp[((n*H_k+oy0)*W_k+ox1)*outC_k + oc] = (T)x1; + } + } + } + } +)METAL"; + +// Output untransform + fused residual add. Extra input: resid (T [N,H,W,outC]). +// Adds in T arithmetic on the rounded conv value -> bit-identical to unfused +// (T)conv + resid. +inline constexpr const char* kWinoOutputSourceResidual = R"METAL( + static_assert(WPT >= 1, "WPT must be positive"); + int Ntiles_k = m_shape[1]; + int outC_k = m_shape[2]; + int H_k = nhwc[1]; + int W_k = nhwc[2]; + int tilesY_k = (H_k + 1) / 2; + int tilesX_k = (W_k + 1) / 2; + uint oc_group = thread_position_in_grid.x; + uint t_group = thread_position_in_grid.y; + for (int w = 0; w < WPT; w++) { + int tileIdx = (int)t_group * WPT + w; + if (tileIdx >= Ntiles_k) break; + int rem = tileIdx; + int n = rem / (tilesY_k * tilesX_k); rem -= n * tilesY_k * tilesX_k; + int ty = rem / tilesX_k; + int tx = rem % tilesX_k; + { + int oc = (int)oc_group; + if (oc >= outC_k) break; + T mm[4][4]; + for (int r = 0; r < 4; r++) + for (int c2 = 0; c2 < 4; c2++) + mm[r][c2] = m[((r*4+c2) * Ntiles_k + tileIdx) * outC_k + oc]; + float tmp[2][4]; + for (int c2 = 0; c2 < 4; c2++) { + float v0=(float)mm[0][c2], v1=(float)mm[1][c2], v2=(float)mm[2][c2], v3=(float)mm[3][c2]; + tmp[0][c2] = v0 + v1 + v2; + tmp[1][c2] = v1 - v2 - v3; + } + for (int a = 0; a < 2; a++) { + float u0=tmp[a][0], u1=tmp[a][1], u2=tmp[a][2], u3=tmp[a][3]; + float Y0 = u0 + u1 + u2; + float Y1 = u1 - u2 - u3; + int oy0 = 2*ty + a; + if (oy0 < H_k) { + int ox0 = 2*tx + 0; + if (ox0 < W_k) { int idx=((n*H_k+oy0)*W_k+ox0)*outC_k+oc; outp[idx] = (T)Y0 + resid[idx]; } + int ox1 = 2*tx + 1; + if (ox1 < W_k) { int idx=((n*H_k+oy0)*W_k+ox1)*outC_k+oc; outp[idx] = (T)Y1 + resid[idx]; } + } + } + } + } +)METAL"; + +inline mx::array winogradConv2d(const mx::array& input, + const mx::array& Uw, + int Cout, + const InputTransform& inCfg, + const OutputUntransform& outCfg, + bool useFP16 = false, + const Epilogue& epi = Epilogue::none()) { + int N = input.shape(0); + int H = input.shape(1); + int W = input.shape(2); + int C = input.shape(3); + int tilesY = (H + 1) / 2; + int tilesX = (W + 1) / 2; + int Ntiles = N * tilesY * tilesX; + + const mx::Dtype dtype = useFP16 ? mx::float16 : mx::float32; + + auto inSuffix = [&](const char* base, int wpt, int vw, GridOrder go) { + return std::string(base) + "_" + (useFP16 ? "f16" : "f32") + + "_w" + std::to_string(wpt) + + "_v" + std::to_string(vw) + + "_g" + std::to_string((int)go); + }; + // Output kernel is monomorphic on VW=1, GRID_ORDER=Cfast, + // and MATMUL_ORIENT=Std. + auto outSuffix = [&](const char* base, int wpt) { + return std::string(base) + "_" + (useFP16 ? "f16" : "f32") + + "_w" + std::to_string(wpt); + }; + std::string inName = inSuffix ("wino_input_transform", inCfg.wpt, inCfg.vw, inCfg.gridOrder); + + auto makeInTemplateArgs = [&](int wpt, int vw, GridOrder go) { + return std::vector>{ + {"T", dtype}, + {"WPT", wpt}, + {"VW", vw}, + {"GRID_ORDER", (int)go} + }; + }; + + // Stage 1: input transform. Output shape: [16, Ntiles, C]. + mx::Shape inOutShape = {16, Ntiles, C}; + + // Grid: when gridOrder=Cfast the fast axis is C (grid x=C, y=Ntiles/WPT). + // When gridOrder=Tfast we swap. WPT>1 reduces the slow-axis dim. + int gridX_in = (inCfg.gridOrder == GridOrder::Cfast) + ? ((C + inCfg.vw - 1) / inCfg.vw) + : Ntiles; + int gridY_in = (inCfg.gridOrder == GridOrder::Cfast) + ? ((Ntiles + inCfg.wpt - 1) / inCfg.wpt) + : ((C + inCfg.wpt - 1) / inCfg.wpt); + + auto inFn = mx::fast::metal_kernel( + inName.c_str(), + /*input_names=*/{"inp"}, + /*output_names=*/{"outp"}, + /*source=*/kWinoInputSource); + auto inOuts = inFn( + /*inputs=*/{input}, + /*output_shapes=*/{ inOutShape }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX_in, gridY_in, 1), + /*threadgroup=*/std::make_tuple(inCfg.tg0, inCfg.tg1, 1), + /*template_args=*/makeInTemplateArgs(inCfg.wpt, inCfg.vw, inCfg.gridOrder), + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + mx::array t = inOuts[0]; + + // Stage 2: matmul. [16,Ntiles,C] @ [16,C,Cout] -> [16,Ntiles,Cout]. + // MLX steel gemm uses AccumType=float (static-asserted in mma.h:772) when + // T=half, so fp32 accumulation is automatic. + mx::array m = mx::matmul(t, Uw); + + // Stage 3: output untransform (+ optional fused epilogue) -> [N, H, W, Cout] + int nhwc_arr[4] = {N, H, W, Cout}; + mx::array nhwcArr(nhwc_arr, {4}, mx::int32); + int gridX_out = Cout; + int gridY_out = (Ntiles + outCfg.wpt - 1) / outCfg.wpt; + + std::string outName; + const char* outSrc; + std::vector outInputNames; + std::vector outInputs; + std::vector> outTpl = { + {"T", dtype}, {"WPT", outCfg.wpt} + }; + if(epi.mode == Epilogue::None) { + outName = outSuffix("wino_output_untransform", outCfg.wpt); // unchanged name + outSrc = kWinoOutputSource; + outInputNames = {"m", "nhwc"}; + outInputs = {m, nhwcArr}; + } else if(epi.mode == Epilogue::BNAct) { + outName = outSuffix("wino_output_untransform", outCfg.wpt) + "_bnact_a" + std::to_string(epi.act); + outSrc = kWinoOutputSourceBNAct; + outInputNames = {"m", "nhwc", "scale", "bias"}; + outInputs = {m, nhwcArr, *epi.scale, *epi.bias}; + outTpl.push_back({"ACT", epi.act}); + } else if(epi.mode == Epilogue::BiasBNAct) { + outName = outSuffix("wino_output_untransform", outCfg.wpt) + "_biasbnact_a" + std::to_string(epi.act); + outSrc = kWinoOutputSourceBiasBNAct; + outInputNames = {"m", "nhwc", "gbias", "scale", "bias"}; + outInputs = {m, nhwcArr, *epi.gbias, *epi.scale, *epi.bias}; + outTpl.push_back({"ACT", epi.act}); + } else { // Residual + outName = outSuffix("wino_output_untransform", outCfg.wpt) + "_resid"; + outSrc = kWinoOutputSourceResidual; + outInputNames = {"m", "nhwc", "resid"}; + outInputs = {m, nhwcArr, *epi.resid}; + } + + auto outFn = mx::fast::metal_kernel( + outName.c_str(), outInputNames, /*output_names=*/{"outp"}, outSrc); + auto outOuts = outFn( + outInputs, + /*output_shapes=*/{ mx::Shape{N, H, W, Cout} }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX_out, gridY_out, 1), + /*threadgroup=*/std::make_tuple(outCfg.tg0, outCfg.tg1, 1), + /*template_args=*/outTpl, + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + return outOuts[0]; +} + +} // namespace MLXWinograd + +#endif // USE_MLX_BACKEND +#endif // NEURALNET_MLXWINOGRAD_H_ diff --git a/cpp/neuralnet/mlxwinotuner.cpp b/cpp/neuralnet/mlxwinotuner.cpp new file mode 100644 index 000000000..363e2f300 --- /dev/null +++ b/cpp/neuralnet/mlxwinotuner.cpp @@ -0,0 +1,1308 @@ +#ifdef USE_MLX_BACKEND + +#include "../neuralnet/mlxwinotuner.h" +#include "../neuralnet/desc.h" +#include "../neuralnet/greedysearch.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // sysctlbyname, for detectGpuName() +#include // getpid(), for atomic temp-file save + +#include "../core/fileutils.h" +#include "../core/global.h" +#include "../core/logger.h" +#include "../core/makedir.h" +#include "../dataio/homedata.h" + +#include "mlx/mlx.h" +#include "mlx/fast.h" +#include +#include +#include + +using namespace std; + +static const int MLX_WINO_TUNER_VERSION = 3; +static const std::string MLX_WINO_TUNEPARAMS_VERSION_LINE = + "VERSION=" + std::to_string(MLX_WINO_TUNER_VERSION); + +// Mirrors OpenCLTuner's readDescKeyValues: parse "KEY=VALUE KEY=VALUE ..." line into a map. +static map parseKeyValueLine(const string& fileName, const string& line) { + map kvs; + vector tokens = Global::split(line); + for(const string& tok : tokens) { + size_t eq = tok.find('='); + if(eq == string::npos) + throw IOError("MLXWinogradTuneParams: token without '=' in " + fileName + " line: " + line); + string k = tok.substr(0, eq); + string v = tok.substr(eq + 1); + if(k.empty()) + throw IOError("MLXWinogradTuneParams: key-value pair without key in " + fileName + " line: " + line); + if(v.empty()) + throw IOError("MLXWinogradTuneParams: key-value pair without value for key '" + k + "' in " + fileName + " line: " + line); + if(kvs.count(k) > 0) + throw IOError("MLXWinogradTuneParams: duplicate key " + k + " in " + fileName); + try { + kvs[k] = Global::stringToInt(v); + } catch(const StringError&) { + throw IOError("MLXWinogradTuneParams: could not parse value for key " + k + " in " + fileName); + } + } + return kvs; +} + +static int requireKey(const map& kvs, const string& key, const string& fileName) { + auto it = kvs.find(key); + if(it == kvs.end()) + throw IOError("MLXWinogradTuneParams: missing key " + key + " in " + fileName); + return it->second; +} + +bool MLXWinogradTuneParams::isValid() const { + if(inputTransform.tg0 <= 0 || inputTransform.tg1 <= 0) return false; + if(outputUntransform.tg0 <= 0 || outputUntransform.tg1 <= 0) return false; + // Bound each threadgroup dim before multiplying. These values come from the + // cache file via an unchecked int parse; a corrupt large pair (e.g. + // 46341*46341) would overflow the int product below (UB) and could wrap to a + // small value that slips past the > 1024 gate. A single Metal threadgroup + // dim can't exceed 1024 anyway, so cap each first. + if(inputTransform.tg0 > 1024 || inputTransform.tg1 > 1024) return false; + if(outputUntransform.tg0 > 1024 || outputUntransform.tg1 > 1024) return false; + if(inputTransform.tg0 * inputTransform.tg1 > 1024) return false; + if(outputUntransform.tg0 * outputUntransform.tg1 > 1024) return false; + // Upper-bound wpt/vw too (also unchecked cache-file ints): the kernels support + // wpt in {1,2,4,8} and vw in {1,2,4}; a corrupt larger value would otherwise + // validate and run a pathological geometry. + if(inputTransform.wpt < 1 || inputTransform.wpt > 8) return false; + if(outputUntransform.wpt < 1 || outputUntransform.wpt > 8) return false; + if(inputTransform.vw < 1 || inputTransform.vw > 4) return false; + // gridOrder is cast from a cache-file int with no range check; reject any + // value outside the defined enum so a corrupt cache re-tunes instead of + // running an unintended (possibly VW-invalid) geometry as if it were Tfast. + if(inputTransform.gridOrder != MLXWinograd::GridOrder::Cfast + && inputTransform.gridOrder != MLXWinograd::GridOrder::Tfast) return false; + // Tfast (GRID_ORDER=1) requires VW=1 in the kernels. Reject any input + // candidate that violates this — surfaces the constraint earlier than + // the Metal JIT static_assert. (Output VW is gone; global gridOrder + // is gone; input gridOrder stands alone.) + if(inputTransform.gridOrder == MLXWinograd::GridOrder::Tfast + && inputTransform.vw != 1) return false; + return true; +} + +void MLXWinogradTuneParams::save(const string& filename, const MLXWinogradTuneParams& params) { + // Write to a per-process-unique temp path, then atomically rename onto the + // final path. This prevents two katago processes that both cache-miss on the + // same model and tune concurrently from tearing the shared cache file. + const string tmpPath = filename + ".tmp." + std::to_string((long)getpid()); + ofstream out; + FileUtils::open(out, tmpPath); + out << MLX_WINO_TUNEPARAMS_VERSION_LINE << "\n"; + out << "#inputTransform\n"; + out << "tg0=" << params.inputTransform.tg0 + << " tg1=" << params.inputTransform.tg1 + << " wpt=" << params.inputTransform.wpt + << " vw=" << params.inputTransform.vw + << " gridOrder=" << (int)params.inputTransform.gridOrder << "\n"; + out << "#outputUntransform\n"; + out << "tg0=" << params.outputUntransform.tg0 + << " tg1=" << params.outputUntransform.tg1 + << " wpt=" << params.outputUntransform.wpt << "\n"; + out.flush(); + out.close(); + // Atomic publish: only fully-written content ever appears at `filename`. + FileUtils::rename(tmpPath, filename); +} + +MLXWinogradTuneParams MLXWinogradTuneParams::load(const string& filename) { + vector raw = FileUtils::readFileLines(filename, '\n'); + vector lines; + for(const string& r : raw) { + string s = Global::stripComments(r); + s = Global::trim(s); + if(!s.empty()) lines.push_back(s); + } + if(lines.empty()) + throw IOError("MLXWinogradTuneParams::load: no content in " + filename); + if(lines[0] != MLX_WINO_TUNEPARAMS_VERSION_LINE) + throw IOError("MLXWinogradTuneParams::load: expected first line to be " + + MLX_WINO_TUNEPARAMS_VERSION_LINE + " in " + filename); + if(lines.size() != 3) + throw IOError("MLXWinogradTuneParams::load: expected 3 non-comment lines in " + filename); + + MLXWinogradTuneParams params; + { + map kvs = parseKeyValueLine(filename, lines[1]); + params.inputTransform.tg0 = requireKey(kvs, "tg0", filename); + params.inputTransform.tg1 = requireKey(kvs, "tg1", filename); + params.inputTransform.wpt = requireKey(kvs, "wpt", filename); + params.inputTransform.vw = requireKey(kvs, "vw", filename); + params.inputTransform.gridOrder = (MLXWinograd::GridOrder)requireKey(kvs, "gridOrder", filename); + } + { + map kvs = parseKeyValueLine(filename, lines[2]); + params.outputUntransform.tg0 = requireKey(kvs, "tg0", filename); + params.outputUntransform.tg1 = requireKey(kvs, "tg1", filename); + params.outputUntransform.wpt = requireKey(kvs, "wpt", filename); + } + return params; +} + +string MLXWinogradTuner::defaultDirectory(bool makeDir, const string& homeDataDirOverride) { + string dir = HomeData::getHomeDataDir(makeDir, homeDataDirOverride); + dir += "/mlxwinotuning"; + if(makeDir) MakeDir::make(dir); + return dir; +} + +string MLXWinogradTuner::defaultFileName(const string& gpuName, + int nnXLen, int nnYLen, + int trunkNumChannels, int modelVersion, + bool useFP16, bool full) { + string clean; + for(char c : gpuName) { + if((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) + clean += c; + } + const char* dtypeSuffix = useFP16 ? "_fp16" : "_fp32"; + // The full (wide-grid) and fast (coarse-grid) tunes produce different winners + // and must NOT share a cache file, otherwise switching the UI Fast/Full mode + // would silently keep loading the other mode's cached params. The fast tune + // keeps the legacy name (no suffix) so existing on-device caches still hit; + // the full tune gets a distinct "_full" file. Both coexist per device/model. + const char* modeSuffix = full ? "_full" : ""; + return Global::strprintf("tunemlxwino%d_gpu%s_x%d_y%d_c%d_mv%d%s%s.txt", + MLX_WINO_TUNER_VERSION, clean.c_str(), + nnXLen, nnYLen, trunkNumChannels, modelVersion, + dtypeSuffix, modeSuffix); +} + +string MLXWinogradTuner::detectGpuName() { + // The optimal Winograd launch geometry differs across Apple GPU variants, so + // the cache key must distinguish them; otherwise a cache tuned on one chip + // (e.g. M1) would be loaded verbatim on another (e.g. M4 Max). MLX does not + // reliably export a device name (mlx::core::metal::device_info() is declared + // but not exported in all libmlx builds), so query the chip brand string + // directly. On Apple Silicon this returns e.g. "Apple M3 Max"; + // defaultFileName() sanitizes it to [A-Za-z0-9]. + char buf[128]; + size_t len = sizeof(buf); + if(sysctlbyname("machdep.cpu.brand_string", buf, &len, nullptr, 0) == 0 && len > 1) { + buf[sizeof(buf) - 1] = '\0'; // guarantee NUL-termination + string name(buf); // stops at the first NUL + if(!name.empty()) + return name; + } + return "AppleSilicon"; +} + +namespace mx = mlx::core; + +namespace { + +// One stage-1 (input transform) timed run on a synthetic [N,H,W,C] tensor. +// Mirrors the inner-loop shape of winogradConv2d's stage 1, but issues only +// the input-transform kernel so we can score it in isolation. Returns wall ms. +// Input kernel always writes Std layout (matmulOrient axis is gone). +static double timeOneInputTransform( + const MLXWinograd::InputTransform& cfg, + const mx::array& input, int channels, + bool useFP16, bool doWarmup) { + int N = input.shape(0); + int H = input.shape(1); + int W = input.shape(2); + int tilesY = (H + 1) / 2; + int tilesX = (W + 1) / 2; + int Ntiles = N * tilesY * tilesX; + + const mx::Dtype dtype = useFP16 ? mx::float16 : mx::float32; + + // Kernel name encodes the still-live axes so the Metal JIT cache sees a + // unique entry per (dtype, wpt, vw, gridOrder) combination. + std::string kernelName = + std::string(useFP16 ? "wino_input_transform_f16" : "wino_input_transform_f32") + + "_w" + std::to_string(cfg.wpt) + + "_v" + std::to_string(cfg.vw) + + "_g" + std::to_string((int)cfg.gridOrder) + + "_tune"; + + auto fn = mx::fast::metal_kernel( + kernelName.c_str(), + /*input_names=*/{"inp"}, + /*output_names=*/{"outp"}, + /*source=*/MLXWinograd::kWinoInputSource); + + // Output shape: [16, Ntiles, C] (Std only). + mx::Shape outShape = {16, Ntiles, channels}; + + // Grid depends on gridOrder: Cfast → (ceil(C/vw), ceil(Ntiles/wpt), 1), + // Tfast → (Ntiles, ceil(C/wpt), 1). + int gridX = (cfg.gridOrder == MLXWinograd::GridOrder::Cfast) + ? ((channels + cfg.vw - 1) / cfg.vw) + : Ntiles; + int gridY = (cfg.gridOrder == MLXWinograd::GridOrder::Cfast) + ? ((Ntiles + cfg.wpt - 1) / cfg.wpt) + : ((channels + cfg.wpt - 1) / cfg.wpt); + + std::vector> tmplArgs = { + {"T", dtype}, + {"WPT", cfg.wpt}, + {"VW", cfg.vw}, + {"GRID_ORDER", (int)cfg.gridOrder} + }; + + // Untimed warmup (gated to the first measured rep per shape): hots + // pipeline-state + lazy-graph caches for THIS config before the timed eval. + // Caller gates so we don't re-warm on every rep. + if(doWarmup) { + auto warmOuts = fn( + /*inputs=*/{input}, + /*output_shapes=*/{ outShape }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX, gridY, 1), + /*threadgroup=*/std::make_tuple(cfg.tg0, cfg.tg1, 1), + /*template_args=*/tmplArgs, + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + mx::eval(warmOuts[0]); + } + + // Timed pass — build fresh lazy node and eval it. + auto outs = fn( + /*inputs=*/{input}, + /*output_shapes=*/{ outShape }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX, gridY, 1), + /*threadgroup=*/std::make_tuple(cfg.tg0, cfg.tg1, 1), + /*template_args=*/tmplArgs, + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + auto t0 = std::chrono::steady_clock::now(); + mx::eval(outs[0]); + auto t1 = std::chrono::steady_clock::now(); + return std::chrono::duration(t1 - t0).count(); +} + +// Same shape for output untransform: synthetic [16, Ntiles, outC] -> [N,H,W,outC]. +// m is always Std-layout ([16, Ntiles, outC]). +static double timeOneOutputUntransform( + const MLXWinograd::OutputUntransform& cfg, + const mx::array& m, int N, int H, int W, int outC, + bool useFP16, bool doWarmup) { + int tilesY = (H + 1) / 2; + int tilesX = (W + 1) / 2; + int Ntiles = N * tilesY * tilesX; + + int nhwc_arr[4] = {N, H, W, outC}; + mx::array nhwcArr(nhwc_arr, {4}, mx::int32); + + const mx::Dtype dtype = useFP16 ? mx::float16 : mx::float32; + + // Kernel name encodes the still-live axes so the Metal JIT cache sees a + // unique entry per (dtype, wpt) combination. (Output kernel is VW=1 + // monomorphic, Cfast monomorphic, and Std-only.) + std::string kernelName = + std::string(useFP16 ? "wino_output_untransform_f16" : "wino_output_untransform_f32") + + "_w" + std::to_string(cfg.wpt) + + "_tune"; + + auto fn = mx::fast::metal_kernel( + kernelName.c_str(), + /*input_names=*/{"m", "nhwc"}, + /*output_names=*/{"outp"}, + /*source=*/MLXWinograd::kWinoOutputSource); + + // Cfast-only grid: (outC, ceil(Ntiles/wpt), 1). + int gridX = outC; + int gridY = (Ntiles + cfg.wpt - 1) / cfg.wpt; + + std::vector> tmplArgs = { + {"T", dtype}, + {"WPT", cfg.wpt} + }; + + // Untimed warmup (gated to the first measured rep per shape): hots + // pipeline-state + lazy-graph caches for THIS config before the timed eval. + // Caller gates so we don't re-warm on every rep. + if(doWarmup) { + auto warmOuts = fn( + /*inputs=*/{m, nhwcArr}, + /*output_shapes=*/{ mx::Shape{N, H, W, outC} }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX, gridY, 1), + /*threadgroup=*/std::make_tuple(cfg.tg0, cfg.tg1, 1), + /*template_args=*/tmplArgs, + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + mx::eval(warmOuts[0]); + } + + // Timed pass — build fresh lazy node and eval it. + auto outs = fn( + /*inputs=*/{m, nhwcArr}, + /*output_shapes=*/{ mx::Shape{N, H, W, outC} }, + /*output_dtypes=*/{ dtype }, + /*grid=*/std::make_tuple(gridX, gridY, 1), + /*threadgroup=*/std::make_tuple(cfg.tg0, cfg.tg1, 1), + /*template_args=*/tmplArgs, + /*init_value=*/std::nullopt, + /*verbose=*/false, + /*stream=*/mx::StreamOrDevice{}); + auto t0 = std::chrono::steady_clock::now(); + mx::eval(outs[0]); + auto t1 = std::chrono::steady_clock::now(); + return std::chrono::duration(t1 - t0).count(); +} + +// Random NHWC input tensor for the input-transform timing harness. +// When useFP16, astype the fp32 source to fp16 so the timed kernel measures +// the active precision. +static mx::array makeRandomInput(int N, int H, int W, int C, uint32_t seed, bool useFP16) { + std::vector v((size_t)N * H * W * C); + std::mt19937 rng(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : v) x = dist(rng); + mx::array arr(v.data(), {N, H, W, C}, mx::float32); + if(useFP16) return mx::astype(arr, mx::float16); + return arr; +} + +// Random [16, Ntiles, outC] tensor for the output-untransform timing harness. +// When useFP16, astype the fp32 source to fp16 so the timed kernel measures +// the active precision. +static mx::array makeRandomMatmulOut(int Ntiles, int outC, uint32_t seed, bool useFP16) { + std::vector v((size_t)16 * Ntiles * outC); + std::mt19937 rng(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + for(auto& x : v) x = dist(rng); + mx::array arr(v.data(), {16, Ntiles, outC}, mx::float32); + if(useFP16) return mx::astype(arr, mx::float16); + return arr; +} + +// Forward decl: planShapeRotation is defined further down in this anonymous +// namespace alongside its policy constants, but the scoring functions above +// reference it. Pure function; safe to forward-declare. +static std::vector +planShapeRotation(const std::vector>& histogram, bool full); + +// Score one input-transform candidate. Adaptive rotation over the model's +// actual 3x3 conv input-channel distribution: planShapeRotation produces a +// list of (channels, measureReps, weight) entries; per shape we time +// `measureReps` reps and take the median, weighted into the final score by +// `weight`. Each shape warms once on its first measured rep (gated via the +// doWarmup arg to timeOneInputTransform); subsequent reps skip the warmup. +static double scoreInputTransform(const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full) { + auto plan = planShapeRotation(mi.conv3x3InputHistogram, full); + assert(!plan.empty()); + + // Pre-build one random input array per planned shape. Each shape warms once + // on its first measured rep (gated via doWarmup), so no separate warmup pass. + std::vector inputs; + inputs.reserve(plan.size()); + uint32_t seed = 0xA1A1A1A1u; + for(const auto& sp : plan) { + inputs.push_back(makeRandomInput(N, H, W, sp.channels, seed, useFP16)); + mx::eval(inputs.back()); + seed = seed * 1664525u + 1013904223u; // distinct seed per shape + } + + double score = 0.0; + for(size_t i = 0; i < plan.size(); i++) { + std::vector samples; + samples.reserve(plan[i].measureReps); + for(int r = 0; r < plan[i].measureReps; r++) { + double ms = timeOneInputTransform(cfg, inputs[i], plan[i].channels, useFP16, /*doWarmup=*/(r == 0)); + samples.push_back(ms); + } + // Median (upper of two middles for even sizes; identical to nth_element + // at index size/2). + std::nth_element(samples.begin(), + samples.begin() + samples.size() / 2, + samples.end()); + double median = samples[samples.size() / 2]; + // A non-finite measurement (nan/inf from a failed/pathological kernel run) + // must NEVER win selection. The tuner minimizes time, so map it to +inf so + // this candidate loses every comparison rather than mapping to 0 (best). + if(!std::isfinite(median)) median = std::numeric_limits::infinity(); + score += plan[i].weight * median; + } + return score; +} + +// Score one output-untransform candidate. Symmetric to scoreInputTransform: +// adaptive rotation over the model's 3x3 conv output-channel distribution. +static double scoreOutputUntransform(const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full) { + int tilesY = (H + 1) / 2; + int tilesX = (W + 1) / 2; + int Ntiles = N * tilesY * tilesX; + + auto plan = planShapeRotation(mi.conv3x3OutputHistogram, full); + assert(!plan.empty()); + + std::vector matmulOuts; + matmulOuts.reserve(plan.size()); + uint32_t seed = 0xD4D4D4D4u; + for(const auto& sp : plan) { + matmulOuts.push_back(makeRandomMatmulOut(Ntiles, sp.channels, seed, useFP16)); + mx::eval(matmulOuts.back()); + seed = seed * 1664525u + 1013904223u; + } + + double score = 0.0; + for(size_t i = 0; i < plan.size(); i++) { + std::vector samples; + samples.reserve(plan[i].measureReps); + for(int r = 0; r < plan[i].measureReps; r++) { + double ms = timeOneOutputUntransform(cfg, matmulOuts[i], N, H, W, + plan[i].channels, useFP16, /*doWarmup=*/(r == 0)); + samples.push_back(ms); + } + std::nth_element(samples.begin(), + samples.begin() + samples.size() / 2, + samples.end()); + double median = samples[samples.size() / 2]; + // A non-finite measurement (nan/inf from a failed/pathological kernel run) + // must NEVER win selection. The tuner minimizes time, so map it to +inf so + // this candidate loses every comparison rather than mapping to 0 (best). + if(!std::isfinite(median)) median = std::numeric_limits::infinity(); + score += plan[i].weight * median; + } + return score; +} + +// Selection-and-allocation policy for the work-weighted shape rotation. +// Pure function. Inputs: list of (channels, occurrence_count) pairs from the +// model's 3x3 conv distribution. Output: vector sorted desc by +// weight, with Σ measureReps == the active rep budget and Σ weight ≈ 1.0. +// +// Selection-rule constants. The per-candidate timing budget depends on the +// sweep breadth: +// - full=true (operator `./katago tuner -full`, the wide grid): 19 timed +// reps, 3-rep floor per minor shape. Score the wide grid carefully. +// - full=false (the per-model-load AUTO tune, the coarse grid): 7 timed reps, +// 2-rep floor. This is the path that runs on every first launch. The winning +// Winograd geometry sits on a broad plateau — many configs land within ~7% +// of each other and end-to-end throughput moves <=1.5% across the whole +// plateau — so a median over 7 reps lands on the plateau just as reliably as +// 19. It only loosens the noise tie-break between near-equivalent configs, +// which is exactly the part that doesn't affect throughput. Dropping the +// per-candidate eval count (1 warmup + reps) from 20 to 8 makes the model- +// load sweep ~2.5x faster. +static constexpr int kMeasureRepsFull = 19; +static constexpr int kMeasureRepsCoarse = 7; +static constexpr int kRepFloorFull = 3; +static constexpr int kRepFloorCoarse = 2; +static constexpr size_t kMaxShapes = 3; +static constexpr double kWorkFractionFloor = 0.03; + +static std::vector +planShapeRotation(const std::vector>& histogram, bool full) { + // Active rep budget: precise for the operator wide-grid tune, fast for the + // per-model-load coarse tune. See the constant block above for the rationale. + const int kMeasureReps = full ? kMeasureRepsFull : kMeasureRepsCoarse; + const int kRepFloor = full ? kRepFloorFull : kRepFloorCoarse; + + // Degenerate case: empty histogram is a model-corruption signal we + // surface, not silently mask. + assert(!histogram.empty()); + + // Step 1: compute work = count * channels; sort desc by work; take top-K. + struct Entry { int channels; long long work; }; + std::vector entries; + entries.reserve(histogram.size()); + for(const auto& [c, n] : histogram) { + if(c <= 0 || n <= 0) continue; + entries.push_back({c, static_cast(c) * static_cast(n)}); + } + assert(!entries.empty()); + + std::sort(entries.begin(), entries.end(), + [](const Entry& a, const Entry& b) { + if(a.work != b.work) return a.work > b.work; + return a.channels > b.channels; // tie-break: larger C first + }); + if(entries.size() > kMaxShapes) + entries.resize(kMaxShapes); + + // Step 2: threshold against post-top-K total work; recompute total. + long long totalWork = 0; + for(const auto& e : entries) totalWork += e.work; + assert(totalWork > 0); + entries.erase( + std::remove_if(entries.begin(), entries.end(), + [totalWork](const Entry& e) { + return static_cast(e.work) / static_cast(totalWork) + < kWorkFractionFloor; + }), + entries.end()); + // Dominant survives (it's the largest; if its share < 3% then total plan; + plan.reserve(entries.size()); + for(const auto& e : entries) { + MLXWinogradTuner::ShapePlan sp; + sp.channels = e.channels; + sp.weight = static_cast(e.work) / static_cast(totalWork); + sp.measureReps = 0; // assigned below + plan.push_back(sp); + } + + // Step 4: allocate kMeasureReps with floor. + if(plan.size() == 1) { + plan[0].measureReps = kMeasureReps; + return plan; + } + + // Tentative round-to-nearest allocation. + for(auto& sp : plan) { + sp.measureReps = static_cast(std::lround(sp.weight * kMeasureReps)); + } + + // Floor-bump: any minor shape below kRepFloor gets bumped, deficit out of dominant. + for(size_t i = 1; i < plan.size(); i++) { + if(plan[i].measureReps < kRepFloor) { + int deficit = kRepFloor - plan[i].measureReps; + plan[i].measureReps += deficit; + plan[0].measureReps -= deficit; + } + } + + // Rounding repair: dominant absorbs +/-1 so Σ == kMeasureReps. + int sum = 0; + for(const auto& sp : plan) sum += sp.measureReps; + plan[0].measureReps += (kMeasureReps - sum); + + // Final invariants. The dominant-underflow assert here can fire only if the + // budget can't cover kMaxShapes floors (kMeasureReps < kMaxShapes*kRepFloor); + // both budgets satisfy that (full: 19>=3*3; coarse: 7>=3*2) so it can't fire. + assert(plan[0].measureReps >= kRepFloor); +#ifndef NDEBUG + int finalSum = 0; + for(const auto& sp : plan) finalSum += sp.measureReps; + assert(finalSum == kMeasureReps); +#endif + + return plan; +} + +// Per-shape median timing for diagnostic logging. Same rotation/plan as the +// scoring functions; reports one (channels, median_ms) entry per planned +// shape instead of a single weighted score. Used by the flat-sweep log's +// "shape_ms=" field and the gated per-shape consistency test. + +static std::vector> +scoreInputTransformPerShape(const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full) { + auto plan = planShapeRotation(mi.conv3x3InputHistogram, full); + assert(!plan.empty()); + + std::vector inputs; + inputs.reserve(plan.size()); + uint32_t seed = 0xA1A1A1A1u; + for(const auto& sp : plan) { + inputs.push_back(makeRandomInput(N, H, W, sp.channels, seed, useFP16)); + mx::eval(inputs.back()); + seed = seed * 1664525u + 1013904223u; + } + + std::vector> out; + out.reserve(plan.size()); + for(size_t i = 0; i < plan.size(); i++) { + std::vector samples; + samples.reserve(plan[i].measureReps); + for(int r = 0; r < plan[i].measureReps; r++) { + samples.push_back( + timeOneInputTransform(cfg, inputs[i], plan[i].channels, useFP16, /*doWarmup=*/(r == 0))); + } + std::nth_element(samples.begin(), + samples.begin() + samples.size() / 2, + samples.end()); + double median = samples[samples.size() / 2]; + if(!std::isfinite(median)) median = 0.0; + out.emplace_back(plan[i].channels, median); + } + return out; +} + +static std::vector> +scoreOutputUntransformPerShape(const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full) { + int Ntiles = N * ((H + 1) / 2) * ((W + 1) / 2); + + auto plan = planShapeRotation(mi.conv3x3OutputHistogram, full); + assert(!plan.empty()); + + std::vector matmulOuts; + matmulOuts.reserve(plan.size()); + uint32_t seed = 0xD4D4D4D4u; + for(const auto& sp : plan) { + matmulOuts.push_back(makeRandomMatmulOut(Ntiles, sp.channels, seed, useFP16)); + mx::eval(matmulOuts.back()); + seed = seed * 1664525u + 1013904223u; + } + + std::vector> out; + out.reserve(plan.size()); + for(size_t i = 0; i < plan.size(); i++) { + std::vector samples; + samples.reserve(plan[i].measureReps); + for(int r = 0; r < plan[i].measureReps; r++) { + samples.push_back( + timeOneOutputUntransform(cfg, matmulOuts[i], N, H, W, + plan[i].channels, useFP16, /*doWarmup=*/(r == 0))); + } + std::nth_element(samples.begin(), + samples.begin() + samples.size() / 2, + samples.end()); + double median = samples[samples.size() / 2]; + if(!std::isfinite(median)) median = 0.0; + out.emplace_back(plan[i].channels, median); + } + return out; +} + +// Candidate axis value sets, in two breadths that mirror OpenCL's tuner: +// +// full=false (default; the model-load AUTO-tune): a COARSE grid of a few +// representative threadgroup / work-per-thread points. Measured on this +// hardware, the winning configs form a broad plateau — many configs land +// within ~7% of each other and ~25-40% above the baked default — and +// geometry moves end-to-end throughput <=1.5%. So a coarse sweep finds the +// plateau in ~2s instead of the wide grid's ~16s, which otherwise burns +// that time discriminating between near-equivalent (and run-to-run noisy) +// configs. +// +// full=true (a deliberate command tune via `./katago tuner -full`): +// the wide grid, for operators who want to squeeze the plateau. Both +// backends pin full=false at model load (openclbackend.cpp / +// mlxbackend.cpp) and reach the wide grid only through the explicit +// tuner command. +// Coarse (model-load) tg sets drop only the extreme threadgroup dims relative +// to a uniform {8,16,32,64,128}×{1,2,4,8,16} sweep: tg0=8 (smallest, rarely the +// occupancy sweet spot for these tiny transform kernels) and tg1=16 (largest; +// pairs with large tg0 to exceed 1024 anyway). The baked default {tg0=32,tg1=1} +// stays in the set, and the surviving points still bracket the full threadgroup- +// size range, so the sweep stays on the broad plateau (see the rep-budget +// comment) while measuring ~1.5x fewer configs. The wide grid (full=true) keeps +// every point for the operator `tuner -full` path. +static const std::vector& inputTg0Values(bool full) { + static const std::vector vFull = {1,2,4,8,16,24,32,48,64,96,128,160,192,256,384,512,1024}; + static const std::vector vCoarse = {16,32,64,128}; + return full ? vFull : vCoarse; +} +static const std::vector& inputTg1Values(bool full) { + static const std::vector vFull = {1,2,4,5,8,10,16,20,25,32,40,50,64,100,128}; + static const std::vector vCoarse = {1,2,4,8}; + return full ? vFull : vCoarse; +} +static const std::vector& outputTg0Values(bool full) { + // Mirror input set — treat tg0 symmetrically. + static const std::vector vFull = {1,2,4,8,16,24,32,48,64,96,128,160,192,256,384,512,1024}; + static const std::vector vCoarse = {16,32,64,128}; + return full ? vFull : vCoarse; +} +static const std::vector& outputTg1Values(bool full) { + static const std::vector vFull = {1,2,4,5,8,10,16,20,25,32,40,50,64,100,128}; + static const std::vector vCoarse = {1,2,4,8}; + return full ? vFull : vCoarse; +} + +// wptValues() is used by both stages; vwValues() is input-only +// (output kernel is VW=1 monomorphic). wpt narrows under the coarse auto grid +// too — the wpt=8 tail rarely wins for these tiny transform kernels. vw has +// only three values, so coarse == full there. +static const std::vector& wptValues(bool full) { + static const std::vector vFull = {1, 2, 4, 8}; + static const std::vector vCoarse = {1, 2, 4}; + return full ? vFull : vCoarse; +} +static const std::vector& vwValues() { + static const std::vector v = {1, 2, 4}; + return v; +} + +// Returns true iff (tg0, tg1, wpt, vw, gridOrder) is structurally valid +// AND vw divides the fast-axis dim of the current stage shape. +static bool isInputCandidateValid(int tg0, int tg1, int wpt, int vw, + MLXWinograd::GridOrder go, + int C, int /*Ntiles*/) { + if(tg0 <= 0 || tg1 <= 0 || wpt <= 0 || vw <= 0) return false; + if(tg0 * tg1 > 1024) return false; + if(go == MLXWinograd::GridOrder::Cfast) { + if(vw > 1 && (C % vw) != 0) return false; + } else { + // Tfast: vw must be 1 (kernel static_assert enforces this). + if(vw != 1) return false; + } + return true; +} +// Output kernel is VW=1 monomorphic — no vw parameter, no +// vw-divisibility check on outC. Output kernel is also Cfast monomorphic +// — no gridOrder parameter. +static bool isOutputCandidateValid(int tg0, int tg1, int wpt, + int /*outC*/, int /*Ntiles*/) { + if(tg0 <= 0 || tg1 <= 0 || wpt <= 0) return false; + if(tg0 * tg1 > 1024) return false; + return true; +} + +static std::vector +buildInputCandidates(bool full, int C, int Ntiles, MLXWinograd::GridOrder go) { + std::vector out; + for(int tg0 : inputTg0Values(full)) + for(int tg1 : inputTg1Values(full)) + for(int wpt : wptValues(full)) + for(int vw : vwValues()) { + if(!isInputCandidateValid(tg0, tg1, wpt, vw, go, C, Ntiles)) continue; + out.push_back({tg0, tg1, wpt, vw, go}); + } + return out; +} +static std::vector +buildOutputCandidates(bool full, int outC, int Ntiles) { + std::vector out; + for(int tg0 : outputTg0Values(full)) + for(int tg1 : outputTg1Values(full)) + for(int wpt : wptValues(full)) { + if(!isOutputCandidateValid(tg0, tg1, wpt, outC, Ntiles)) continue; + out.push_back({tg0, tg1, wpt}); + } + return out; +} + +// Renders the per-shape median-timing suffix " shape_ms=c:,..." shared +// by both flat sweeps. An empty vector yields an empty string. +static std::string renderPerShapeMs(const std::vector>& perShape) { + std::string s = " shape_ms="; + for(size_t i = 0; i < perShape.size(); i++) { + if(i > 0) s += ","; + s += "c" + std::to_string(perShape[i].first) + + ":" + Global::strprintf("%.3f", perShape[i].second); + } + return s; +} + +// Writes the two-line sweep diagnostic shared by flatSweepInput/flatSweepOutput: +// an optional skipped-count line, then the considered/best/baseline/delta/per-shape +// summary. `label` is the sweep name (e.g. "flatSweepInput"); `bestFields` is the +// caller-rendered "tg0=.. tg1=.. .." body (input adds vw/gridOrder, output omits +// them) and is unused when haveBest is false; `perShapeStr` is the renderPerShapeMs +// suffix or empty. delta_pct is computed here on the same (haveBest && baseline>=1e-9) +// condition the callers use to build perShapeStr, so the "nan" degenerate branch +// stays in lockstep. This function owns the regex-pinned log format (see mlxtests.cpp). +static void logFlatSweep( + Logger* logger, const char* label, int considered, int skipped, + bool haveBest, const std::string& bestFields, double bestTime, + double baselineMs, const std::string& perShapeStr) { + if(logger == nullptr) return; + if(skipped > 0) + logger->write(std::string("MLX tuner ") + label + " skipped=" + std::to_string(skipped) + + " candidate(s) that failed to score; kept best valid config"); + // %+.1f always emits a sign; the gated log-format test regex relies on this + // (matches [-+], not [-+]?). Don't drop the + flag. + const std::string deltaStr = (haveBest && baselineMs >= 1e-9) + ? Global::strprintf("%+.1f", (bestTime - baselineMs) / baselineMs * 100.0) + : std::string("nan"); + logger->write(std::string("MLX tuner ") + label + ": considered=" + std::to_string(considered) + + (haveBest + ? " best=" + bestFields + " time_ms=" + Global::strprintf("%.3f", bestTime) + : " best=none") + + " baseline_ms=" + Global::strprintf("%.3f", baselineMs) + + " delta_pct=" + deltaStr + + perShapeStr); +} + +// Flat sweep over (tg0, tg1, wpt, vw, gridOrder) for the input transform. +// Returns the best (lowest-time) +// candidate that passes isInputCandidateValid; nullopt if no candidate is +// valid (defensive -- should not happen for a real model). +static std::optional +flatSweepInput(int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full, bool useGreedy, Logger* logger, int* consideredOut) { + using GO = MLXWinograd::GridOrder; + // Candidate enumeration's vw-divisibility filter uses C as the most + // restrictive channel count the kernel will encounter. Use the max of the + // model's actual 3x3 input channel distribution. + int C = 0; + for(const auto& p : mi.conv3x3InputHistogram) C = std::max(C, p.first); + assert(C > 0); + const int tilesY = (H + 1) / 2; + const int tilesX = (W + 1) / 2; + const int Ntiles = N * tilesY * tilesX; + + // Score the baked default (default-constructed = {tg0=32, tg1=1, wpt=1, + // vw=1, gridOrder=Cfast}) so the sweep log carries a baseline the operator + // can compare the winner against. Always adopted-winner; no fallback. + // The defaults satisfy isInputCandidateValid for any (C, Ntiles) because + // vw=1 divides every channel count; see mlxwinograd.h for the struct defaults. + const double baselineMs = + scoreInputTransform(MLXWinograd::InputTransform{}, N, H, W, mi, useFP16, full); + + // Seed the floor with the baked default so a sweep in which every candidate + // throws still yields a valid result instead of aborting model load. The + // default ({tg0=32,...}, 32 threads) always passes isInputCandidateValid and + // never exceeds maxTotalThreadsPerThreadgroup, so it scores without throwing. + std::optional best = MLXWinograd::InputTransform{}; + double bestTime = baselineMs; + int considered = 0; + int skipped = 0; + + // The output gridOrder check in isValid() is gone (output kernel is + // Cfast-monomorphic), so the input gridOrder axis can be searched over + // both Cfast and Tfast. The global gridOrder field is also gone — + // input gridOrder stands alone, no cross-stage consistency to enforce. + if(useGreedy) { + // Sensitivity-ordered greedy coordinate descent over the coarse axes. + // Axis 0: tg0, 1: tg1, 2: wpt, 3: (gridOrder,vw) joint — encoding the joint + // axis makes the Tfast->vw=1 coupling a matter of enumeration, not rejection. + const std::vector& tg0v = inputTg0Values(false); + const std::vector& tg1v = inputTg1Values(false); + const std::vector& wptv = wptValues(false); + const std::vector& vwv = vwValues(); + struct GoVw { MLXWinograd::GridOrder go; int vw; }; + std::vector goVw; + for(int vw : vwv) goVw.push_back({MLXWinograd::GridOrder::Cfast, vw}); + goVw.push_back({MLXWinograd::GridOrder::Tfast, 1}); + + const std::vector axisSizes = {(int)tg0v.size(), (int)tg1v.size(), (int)wptv.size(), (int)goVw.size()}; + // Sensitivity order — MEASURED on A15: joint(gridOrder,vw) dominates, then + // tg1 > tg0 > wpt (all ~1-4%, plateau). axis order: joint(3), tg1(1), tg0(0), wpt(2). + const std::vector order = {3, 1, 0, 2}; + // Seed = baked default {tg0=32,tg1=1,wpt=1,(Cfast,1)} as indices, given the + // coarse sets {16,32,64,128}/{1,2,4,8}/{1,2,4}/goVw[0]=(Cfast,1). + // These are indices into the coarse value sets above — update if those sets change. + const std::vector seed = {1, 0, 0, 0}; + + auto decode = [&](const std::vector& idx) { + return MLXWinograd::InputTransform{ tg0v[idx[0]], tg1v[idx[1]], wptv[idx[2]], + goVw[idx[3]].vw, goVw[idx[3]].go }; + }; + // Guard the hard-coded seed indices against silent drift: if someone reorders + // the coarse value sets above, the seed must still decode to the baked default. + { + const MLXWinograd::InputTransform sd = decode(seed), def{}; + if(sd.tg0 != def.tg0 || sd.tg1 != def.tg1 || sd.wpt != def.wpt + || sd.vw != def.vw || sd.gridOrder != def.gridOrder) + throw StringError("MLX winotuner: greedy input seed no longer maps to the baked " + "default; update the seed indices for the reordered value sets"); + } + auto scoreFn = [&](const std::vector& idx) -> double { + MLXWinograd::InputTransform cand = decode(idx); + if(!isInputCandidateValid(cand.tg0, cand.tg1, cand.wpt, cand.vw, cand.gridOrder, C, Ntiles)) + return std::numeric_limits::infinity(); + double t; + try { t = scoreInputTransform(cand, N, H, W, mi, useFP16, full); } + catch(const std::exception&) { return std::numeric_limits::infinity(); } +#ifdef MLX_TUNE_STUDY + std::fprintf(stderr, "[MLX-STUDY] in full=%d go=%d tg0=%d tg1=%d wpt=%d vw=%d score=%.4f\n", + full ? 1 : 0, (int)cand.gridOrder, cand.tg0, cand.tg1, cand.wpt, cand.vw, t); +#endif + return t; + }; + + GreedySearch::Result gr = GreedySearch::coordinateDescent(axisSizes, order, seed, scoreFn, /*maxPasses=*/3); + best = decode(gr.indices); // assign the EXISTING `best` + bestTime = gr.score; // keep the existing logger's delta meaningful + considered = gr.evaluated; // assign the EXISTING `considered` + } else { + for(GO go : {GO::Cfast, GO::Tfast}) { + auto cands = MLXWinogradTuner::buildInputCandidatesForTesting(full, C, Ntiles, go); + for(const auto& cand : cands) { + considered++; + double t; + try { + t = scoreInputTransform(cand, N, H, W, mi, useFP16, full); + } catch(const std::exception&) { + // A candidate whose threadgroup exceeds the pipeline's register-pressure- + // dependent maxTotalThreadsPerThreadgroup (can be < 1024), or that hits a + // transient GPU error, throws out of mx::eval. Skip it; the seeded default + // remains the valid floor. + skipped++; + continue; + } +#ifdef MLX_TUNE_STUDY + std::fprintf(stderr, + "[MLX-STUDY] in full=%d go=%d tg0=%d tg1=%d wpt=%d vw=%d score=%.4f\n", + full ? 1 : 0, (int)cand.gridOrder, cand.tg0, cand.tg1, cand.wpt, cand.vw, t); +#endif + if(t < bestTime) { bestTime = t; best = cand; } + } + } + } + // Render this sweep's best-config fields (input carries vw/gridOrder) and the + // per-shape suffix, then hand off to the shared logger. perShapeStr is built + // only on the same (best && baseline>=1e-9) condition logFlatSweep uses for + // delta_pct, keeping the degenerate best=none / nan branch in lockstep. + std::string bestFields, perShapeStr; + if(best) { + bestFields = "tg0=" + std::to_string(best->tg0) + + " tg1=" + std::to_string(best->tg1) + + " wpt=" + std::to_string(best->wpt) + + " vw=" + std::to_string(best->vw) + + " gridOrder=" + std::to_string((int)best->gridOrder); + if(baselineMs >= 1e-9) { + // Per-shape median timing on the winner — diagnostic only; winner + // selection above used the weighted score from scoreInputTransform. + perShapeStr = renderPerShapeMs(scoreInputTransformPerShape(*best, N, H, W, mi, useFP16, full)); + } + } + logFlatSweep(logger, "flatSweepInput", considered, skipped, + (bool)best, bestFields, bestTime, baselineMs, perShapeStr); + if(consideredOut) *consideredOut = considered; + return best; +} + +// Flat sweep over (tg0, tg1, wpt) for the output untransform. Output VW +// and gridOrder are not searched: the kernel is monomorphic on VW=1 and +// Cfast. +static std::optional +flatSweepOutput(int N, int H, int W, + const MLXWinogradTuner::ModelInfoForTuning& mi, + bool useFP16, bool full, bool useGreedy, Logger* logger, int* consideredOut) { + // Output-untransform candidate enumeration doesn't filter on outC + // (isOutputCandidateValid ignores it — VW=1 monomorphic), but we still + // pass a representative value. Use the max of the model's actual 3x3 + // output distribution. + int outC = 0; + for(const auto& p : mi.conv3x3OutputHistogram) outC = std::max(outC, p.first); + assert(outC > 0); + const int Ntiles = N * ((H + 1) / 2) * ((W + 1) / 2); + + // Score the baked default (default-constructed = {tg0=32, tg1=1, wpt=1}) + // so the sweep log carries a baseline the operator can compare the winner + // against. Symmetric to flatSweepInput. + const double baselineMs = + scoreOutputUntransform(MLXWinograd::OutputUntransform{}, N, H, W, mi, useFP16, full); + + // Seed the floor with the baked default (see flatSweepInput for rationale). + std::optional best = MLXWinograd::OutputUntransform{}; + double bestTime = baselineMs; + int considered = 0; + int skipped = 0; + + // Output kernel is VW=1 monomorphic and Cfast monomorphic, so neither + // VW nor gridOrder is searched here. + if(useGreedy) { + const std::vector& tg0v = outputTg0Values(false); + const std::vector& tg1v = outputTg1Values(false); + const std::vector& wptv = wptValues(false); + const std::vector axisSizes = {(int)tg0v.size(), (int)tg1v.size(), (int)wptv.size()}; + // Sensitivity order — MEASURED on A15: tg0(6%) > tg1(2%) > wpt(1.8%), all plateau. + const std::vector order = {0, 1, 2}; + // Indices into the coarse value sets above — update if those sets change. + const std::vector seed = {1, 0, 0}; // {tg0=32,tg1=1,wpt=1} + // Guard the hard-coded seed against silent drift if the value sets are reordered. + { + const MLXWinograd::OutputUntransform sd{ tg0v[seed[0]], tg1v[seed[1]], wptv[seed[2]] }, def{}; + if(sd.tg0 != def.tg0 || sd.tg1 != def.tg1 || sd.wpt != def.wpt) + throw StringError("MLX winotuner: greedy output seed no longer maps to the baked default"); + } + + auto scoreFn = [&](const std::vector& idx) -> double { + MLXWinograd::OutputUntransform cand{ tg0v[idx[0]], tg1v[idx[1]], wptv[idx[2]] }; + if(!isOutputCandidateValid(cand.tg0, cand.tg1, cand.wpt, outC, Ntiles)) + return std::numeric_limits::infinity(); + double t; + try { t = scoreOutputUntransform(cand, N, H, W, mi, useFP16, full); } + catch(const std::exception&) { return std::numeric_limits::infinity(); } +#ifdef MLX_TUNE_STUDY + std::fprintf(stderr, "[MLX-STUDY] out full=%d tg0=%d tg1=%d wpt=%d score=%.4f\n", + full ? 1 : 0, cand.tg0, cand.tg1, cand.wpt, t); +#endif + return t; + }; + + GreedySearch::Result gr = GreedySearch::coordinateDescent(axisSizes, order, seed, scoreFn, /*maxPasses=*/3); + best = MLXWinograd::OutputUntransform{ tg0v[gr.indices[0]], tg1v[gr.indices[1]], wptv[gr.indices[2]] }; + bestTime = gr.score; + considered = gr.evaluated; + } else { + auto cands = MLXWinogradTuner::buildOutputCandidatesForTesting(full, outC, Ntiles); + for(auto cand : cands) { + considered++; + double t; + try { + t = scoreOutputUntransform(cand, N, H, W, mi, useFP16, full); + } catch(const std::exception&) { + skipped++; + continue; + } +#ifdef MLX_TUNE_STUDY + std::fprintf(stderr, + "[MLX-STUDY] out full=%d tg0=%d tg1=%d wpt=%d score=%.4f\n", + full ? 1 : 0, cand.tg0, cand.tg1, cand.wpt, t); +#endif + if(t < bestTime) { bestTime = t; best = cand; } + } + } + // Symmetric to flatSweepInput; the output best-config has no vw/gridOrder. + std::string bestFields, perShapeStr; + if(best) { + bestFields = "tg0=" + std::to_string(best->tg0) + + " tg1=" + std::to_string(best->tg1) + + " wpt=" + std::to_string(best->wpt); + if(baselineMs >= 1e-9) + perShapeStr = renderPerShapeMs(scoreOutputUntransformPerShape(*best, N, H, W, mi, useFP16, full)); + } + logFlatSweep(logger, "flatSweepOutput", considered, skipped, + (bool)best, bestFields, bestTime, baselineMs, perShapeStr); + if(consideredOut) *consideredOut = considered; + return best; +} + +} // namespace + +MLXWinogradTuneParams MLXWinogradTuner::loadOrAutoTune( + string tunerFile, + const string& homeDataDirOverride, + const string& gpuName, + int nnXLen, int nnYLen, int batchSize, + ModelInfoForTuning modelInfo, + Logger* logger, + bool full, + bool reTune, + bool useFP16) { + if(tunerFile.empty()) { + string dir = defaultDirectory(true, homeDataDirOverride); + tunerFile = dir + "/" + defaultFileName(gpuName, nnXLen, nnYLen, + modelInfo.trunkNumChannels, + modelInfo.modelVersion, useFP16, full); + } + + // Cache load path: if the file exists, validates, and reTune is false, use it. + if(!reTune && !tunerFile.empty() && FileUtils::exists(tunerFile)) { + try { + MLXWinogradTuneParams loaded = MLXWinogradTuneParams::load(tunerFile); + if(loaded.isValid()) { + if(logger) + logger->write("Loaded MLX Winograd tuning parameters from " + tunerFile); + return loaded; + } + if(logger) + logger->write("MLX Winograd cache " + tunerFile + " failed isValid(); re-tuning"); + } catch(const IOError& e) { + if(logger) + logger->write(std::string("MLX Winograd cache load failed: ") + e.what() + "; re-tuning"); + } + } + + // Flat per-stage sweep. Each sweep logs its own considered-count via `logger`; + // the per-stage considered counters are no longer surfaced separately. + auto t0 = std::chrono::steady_clock::now(); + auto bestIn = flatSweepInput (batchSize, nnYLen, nnXLen, modelInfo, useFP16, full, /*useGreedy=*/!full, logger, /*consideredOut=*/nullptr); + auto bestOut = flatSweepOutput(batchSize, nnYLen, nnXLen, modelInfo, useFP16, full, /*useGreedy=*/!full, logger, /*consideredOut=*/nullptr); + auto t1 = std::chrono::steady_clock::now(); + double tuneMs = std::chrono::duration(t1 - t0).count(); + if(logger) + logger->write("MLX tuner flat sweep complete in " + Global::strprintf("%.0f", tuneMs) + " ms"); + + if(!bestIn || !bestOut) + throw StringError("MLXWinogradTuner: flat sweep returned no valid candidate"); + + MLXWinogradTuneParams result; + result.inputTransform = *bestIn; + result.outputUntransform = *bestOut; + // Global gridOrder is deleted; input gridOrder stands alone. + + if(!result.isValid()) + throw StringError("MLXWinogradTuner: flat sweep result failed isValid()"); + +#ifdef MLX_TUNE_STUDY + if(!full) { + // Coarse EXHAUSTIVE reference (same coarse value sets as greedy, but full + // search: useGreedy=false) — apples-to-apples with the greedy winner. This + // isolates "greedy vs coarse-exhaustive" from the separate "coarse vs wide" + // question (the coarse breadth is already accepted). Dev-only. + int exIn = 0, exOut = 0; + auto exBestIn = flatSweepInput (batchSize, nnYLen, nnXLen, modelInfo, useFP16, /*full=*/false, /*useGreedy=*/false, nullptr, &exIn); + auto exBestOut = flatSweepOutput(batchSize, nnYLen, nnXLen, modelInfo, useFP16, /*full=*/false, /*useGreedy=*/false, nullptr, &exOut); + double greedyInMs = scoreInputTransformForTesting (result.inputTransform, batchSize, nnYLen, nnXLen, modelInfo, useFP16); + double greedyOutMs = scoreOutputUntransformForTesting(result.outputUntransform, batchSize, nnYLen, nnXLen, modelInfo, useFP16); + double exInMs = exBestIn ? scoreInputTransformForTesting (*exBestIn, batchSize, nnYLen, nnXLen, modelInfo, useFP16) : 0.0; + double exOutMs = exBestOut ? scoreOutputUntransformForTesting(*exBestOut, batchSize, nnYLen, nnXLen, modelInfo, useFP16) : 0.0; + double gT = greedyInMs + greedyOutMs, eT = exInMs + exOutMs; + double deltaPct = (eT > 1e-9) ? (gT - eT) / eT * 100.0 : 0.0; + std::fprintf(stderr, + "[MLX-ACCEPT] greedy_ms=%.4f coarse_exhaustive_ms=%.4f delta_pct=%+.1f within5=%d\n", + gT, eT, deltaPct, (deltaPct <= 5.0) ? 1 : 0); + } +#endif + + if(!tunerFile.empty()) { + MLXWinogradTuneParams::save(tunerFile, result); + if(logger) + logger->write("Saved MLX Winograd tuning parameters to " + tunerFile); + } + return result; +} + +std::vector +MLXWinogradTuner::buildInputCandidatesForTesting(bool full, int C, int Ntiles, MLXWinograd::GridOrder go) { + return buildInputCandidates(full, C, Ntiles, go); +} +std::vector +MLXWinogradTuner::buildOutputCandidatesForTesting(bool full, int outC, int Ntiles) { + return buildOutputCandidates(full, outC, Ntiles); +} + +std::vector +MLXWinogradTuner::planShapeRotationForTesting( + const std::vector>& histogram, bool full) { + return planShapeRotation(histogram, full); +} + +double MLXWinogradTuner::scoreInputTransformForTesting( + const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full) { + return scoreInputTransform(cfg, N, H, W, mi, useFP16, full); +} + +double MLXWinogradTuner::scoreOutputUntransformForTesting( + const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full) { + return scoreOutputUntransform(cfg, N, H, W, mi, useFP16, full); +} + +std::vector> +MLXWinogradTuner::scoreInputTransformPerShapeForTesting( + const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full) { + return scoreInputTransformPerShape(cfg, N, H, W, mi, useFP16, full); +} + +std::vector> +MLXWinogradTuner::scoreOutputUntransformPerShapeForTesting( + const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full) { + return scoreOutputUntransformPerShape(cfg, N, H, W, mi, useFP16, full); +} + +std::string MLXWinogradTuner::formatConv3x3DistributionLine( + int total, + const std::map& inputChannelCounts, + const std::map& outputChannelCounts) { + // Build a deterministic ordering: pairs sorted descending by invocation + // count, ties broken by channel count descending. Truncate each histogram + // to top-10 with a trailing ",..." guard for pathological models. + auto serialize = [](const std::map& counts) -> std::string { + if(counts.empty()) return "{}"; + std::vector> pairs(counts.begin(), counts.end()); + std::sort(pairs.begin(), pairs.end(), + [](const std::pair& a, const std::pair& b) { + if(a.second != b.second) return a.second > b.second; + return a.first > b.first; + }); + constexpr size_t kMax = 10; + bool truncated = pairs.size() > kMax; + if(truncated) pairs.resize(kMax); + + std::string s; + for(size_t i = 0; i < pairs.size(); i++) { + if(i > 0) s += ","; + s += std::to_string(pairs[i].first) + ":" + std::to_string(pairs[i].second); + } + if(truncated) s += ",..."; + return s; + }; + + return "MLX tuner conv3x3 distribution: total=" + std::to_string(total) + + " input_c=" + serialize(inputChannelCounts) + + " output_c=" + serialize(outputChannelCounts); +} + +// Pure core: filter to 3x3 convs and emit (channels, count) histograms. +// Decoupled from ModelDesc so it's testable without synthesizing the +// copy-deleted ModelDesc hierarchy. Takes pointers because ConvLayerDesc +// has a deleted copy ctor; pointers must be non-null and outlive the call. +static std::pair>, + std::vector>> +buildConv3x3HistogramsFromConvs(const std::vector& convs) { + std::map inputC, outputC; + for(const ConvLayerDesc* c : convs) { + if(c->convXSize == 3 && c->convYSize == 3) { + inputC[c->inChannels]++; + outputC[c->outChannels]++; + } + } + std::vector> inVec(inputC.begin(), inputC.end()); + std::vector> outVec(outputC.begin(), outputC.end()); + return {std::move(inVec), std::move(outVec)}; +} + +std::pair>, + std::vector>> +MLXWinogradTuner::buildConv3x3HistogramsFromConvsForTesting( + const std::vector& convs) { + return buildConv3x3HistogramsFromConvs(convs); +} + +// ModelDesc shim. Walks iterConvLayers, collects pointers to the +// descriptors owned by modelDesc, and delegates to the pure core. Used +// by mlxbackend.cpp at model load. The returned histograms reference no +// memory from modelDesc — only ints — so the descriptor lifetime +// requirement is local to this call. +std::pair>, + std::vector>> +MLXWinogradTuner::buildConv3x3Histograms(const ModelDesc& modelDesc) { + std::vector convs; + modelDesc.iterConvLayers([&](const ConvLayerDesc& c) { convs.push_back(&c); }); + return buildConv3x3HistogramsFromConvs(convs); +} + +std::string MLXWinogradTuner::formatConv3x3Distribution(const ModelDesc& modelDesc) { + // Convenience wrapper for callers that want the formatted line directly + // from a ModelDesc. The histogram is built here and (separately) again by + // mlxbackend.cpp for the tuner's ModelInfoForTuning — two walks per model + // load. This is acceptable because model load happens once per process; + // a single-walk refactor would tangle the mlxbackend call site without + // measurable savings. + auto [inVec, outVec] = MLXWinogradTuner::buildConv3x3Histograms(modelDesc); + std::map inMap(inVec.begin(), inVec.end()); + std::map outMap(outVec.begin(), outVec.end()); + int total = 0; + for(const auto& kv : outVec) total += kv.second; // total = #3x3 convs + return formatConv3x3DistributionLine(total, inMap, outMap); +} + +#endif // USE_MLX_BACKEND diff --git a/cpp/neuralnet/mlxwinotuner.h b/cpp/neuralnet/mlxwinotuner.h new file mode 100644 index 000000000..ac1b7d88f --- /dev/null +++ b/cpp/neuralnet/mlxwinotuner.h @@ -0,0 +1,179 @@ +#ifndef NEURALNET_MLXWINOTUNER_H_ +#define NEURALNET_MLXWINOTUNER_H_ + +#ifdef USE_MLX_BACKEND + +#include +#include +#include +#include +#include "../neuralnet/mlxwinograd.h" + +class Logger; +struct ModelDesc; +struct ConvLayerDesc; + +struct MLXWinogradTuneParams { + MLXWinograd::InputTransform inputTransform; + MLXWinograd::OutputUntransform outputUntransform; + + // tg0 * tg1 <= 1024, all positive. Input gridOrder stands alone (no global + // companion; output kernel is Cfast-monomorphic). + // vw must divide the fast-axis dim of the current model — + // that check happens at candidate-enumeration time, not here. + bool isValid() const; + + // VERSION=3 plain-text persistence. Format: + // VERSION=3 + // #inputTransform + // tg0= tg1= wpt= vw= gridOrder=<0|1> + // #outputUntransform + // tg0= tg1= wpt= + static void save(const std::string& filename, const MLXWinogradTuneParams& params); + static MLXWinogradTuneParams load(const std::string& filename); +}; + +namespace MLXWinogradTuner { + struct ModelInfoForTuning { + int trunkNumChannels; // cache file key + int modelVersion; // cache file key + std::vector> conv3x3InputHistogram; + std::vector> conv3x3OutputHistogram; + }; + + // Per-shape rep allocation produced by planShapeRotation. The tuner loops + // over a vector when scoring a candidate: each entry contributes + // `weight * median(time over `measureReps` reps at this channel count)` to + // the total score. + struct ShapePlan { + int channels; // C value to time + int measureReps; // number of timing reps (does not include warmup) + double weight; // normalized score weight, Σ weights == 1.0 + }; + + // Pure, deterministic. Given (channel, count) pairs, returns the planned + // rotation: + // 1. work_i = count_i * channels_i; sort desc by work; take top-3. + // 2. drop shapes with work < 3% of the post-top3 total work; renormalize. + // 3. weight_i = work_i / total_work after renormalization. + // 4. allocate the rep budget proportionally; bump any below the floor up to + // the floor, taking the deficit from the dominant shape; repair rounding + // so the dominant absorbs the +/-1 to make Σ measureReps == budget exactly. + // Budget/floor are 19/3 for full=true, 7/2 for full=false (model load). + // Asserts on empty input. + // full selects the rep budget (true: 19-rep precise; false: 7-rep coarse, + // the per-model-load path). Defaults true so existing call sites are + // unaffected; pass false to exercise the coarse-budget allocation. + std::vector planShapeRotationForTesting( + const std::vector>& histogram, bool full = true); + + // Chip-specific identifier for the cache-file key (e.g. "Apple M3 Max" via + // sysctl machdep.cpu.brand_string). The optimal Winograd launch geometry + // differs across Apple GPU variants, so different chips must not share one + // tuned cache; defaultFileName() strips this to [A-Za-z0-9]. Both the backend + // load path and the `tuner` command call this so their keys always match. + // Falls back to "AppleSilicon" if the query fails. + std::string detectGpuName(); + + std::string defaultDirectory(bool makeDir, const std::string& homeDataDirOverride); + // full selects the cache-file variant: false → coarse "fast" tune (legacy + // name, no mode suffix), true → wide "full" tune ("_full" suffix). The two + // produce different winners so they must not share a file. Defaults false. + std::string defaultFileName(const std::string& gpuName, + int nnXLen, int nnYLen, + int trunkNumChannels, int modelVersion, + bool useFP16, bool full = false); + + // Loads existing tune file if present and valid; otherwise runs the two + // grid searches, saves the result, and returns it. + // useFP16: passed to defaultFileName for cache-file naming AND to the + // search-timing kernels so geometry is measured at the active precision. + MLXWinogradTuneParams loadOrAutoTune( + std::string tunerFile, + const std::string& homeDataDirOverride, + const std::string& gpuName, + int nnXLen, int nnYLen, int batchSize, + ModelInfoForTuning modelInfo, + Logger* logger, + bool full, + bool reTune, + bool useFP16 + ); + + // Test-only — exposes the per-model candidate enumeration. Not part of the + // stable API; production callers should use loadOrAutoTune. + std::vector + buildInputCandidatesForTesting(bool full, int C, int Ntiles, MLXWinograd::GridOrder go); + std::vector + buildOutputCandidatesForTesting(bool full, int outC, int Ntiles); + + // Test-only — exposes the per-stage scoring primitives so tests can compare + // configs apples-to-apples without depending on the full tuner measurement path. + double scoreInputTransformForTesting(const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full = true); + double scoreOutputUntransformForTesting(const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full = true); + + // Per-shape median timing for diagnostic logging. Same rotation as the + // scoring functions, but reports median per planned shape instead of a + // single weighted score. One entry per shape in planShapeRotation's + // output, in the same order (dominant first). Used by the flat-sweep + // log "shape_ms=" field and the gated per-shape consistency test. + std::vector> + scoreInputTransformPerShapeForTesting(const MLXWinograd::InputTransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full = true); + std::vector> + scoreOutputUntransformPerShapeForTesting(const MLXWinograd::OutputUntransform& cfg, + int N, int H, int W, + const ModelInfoForTuning& mi, + bool useFP16, bool full = true); + + // Conv-3x3 shape distribution log: one-line summary of the model's 3x3 + // conv shape mix, computed at model load and printed alongside the tuner + // log so operators can correlate cached winners with the per-pass shape + // distribution the cache was tuned for. Pure formatter is exposed for + // testability; wrapper does the descriptor walk. + // + // formatConv3x3DistributionLine: pure function — given pre-computed + // histograms keyed by channel count, returns the log line. No I/O. + std::string formatConv3x3DistributionLine( + int total, + const std::map& inputChannelCounts, + const std::map& outputChannelCounts); + + // formatConv3x3Distribution: delegates to buildConv3x3Histograms, then + // rebuilds maps and calls formatConv3x3DistributionLine. Single line; + // safe to log on every model load. + std::string formatConv3x3Distribution(const ModelDesc& modelDesc); + + // Pure core of the conv-3x3 histogram build: filters to 3x3, returns + // (channels, count) vectors for inputs and outputs. Decoupled from + // ModelDesc so it can be tested without synthesizing the + // copy-deleted/stream-constructed ModelDesc hierarchy. + // + // NOTE on the pointer signature: ConvLayerDesc has a deleted copy ctor + // (desc.h:29), so we cannot collect them by value. The shim collects + // pointers to descriptors owned by the ModelDesc; the test constructs + // descriptors in a local vector via emplace_back and passes pointers. + // All pointers must be non-null and outlive the call. + std::pair>, + std::vector>> + buildConv3x3HistogramsFromConvsForTesting( + const std::vector& convs); + + // ModelDesc shim. Walks modelDesc.iterConvLayers into a pointer vector + // and delegates to the pure core above. Used by mlxbackend.cpp at model + // load. + std::pair>, + std::vector>> + buildConv3x3Histograms(const ModelDesc& modelDesc); +} + +#endif // USE_MLX_BACKEND +#endif // NEURALNET_MLXWINOTUNER_H_ diff --git a/cpp/program/setup.cpp b/cpp/program/setup.cpp index 8a39b9fa6..5761caba2 100644 --- a/cpp/program/setup.cpp +++ b/cpp/program/setup.cpp @@ -21,6 +21,7 @@ std::vector Setup::getBackendPrefixes() { prefixes.push_back("metal"); prefixes.push_back("opencl"); prefixes.push_back("eigen"); + prefixes.push_back("mlx"); prefixes.push_back("dummybackend"); return prefixes; } @@ -89,6 +90,8 @@ vector Setup::initializeNNEvaluators( string backendPrefix = "opencl"; #elif defined(USE_EIGEN_BACKEND) string backendPrefix = "eigen"; + #elif defined(USE_MLX_BACKEND) + string backendPrefix = "mlx"; #else string backendPrefix = "dummybackend"; #endif diff --git a/cpp/rungpuerrortest.sh b/cpp/rungpuerrortest.sh index 9827d2a7c..1372617b7 100755 --- a/cpp/rungpuerrortest.sh +++ b/cpp/rungpuerrortest.sh @@ -1,10 +1,20 @@ #!/bin/bash -eux -# Optional first argument: extra config overrides appended (comma-separated) to the -# -override-config of every test entry below. E.g. ./rungpuerrortest.sh "useFP16=true, useNHWC=true" -EXTRA_OVERRIDE="" -if [ -n "${1:-}" ]; then - EXTRA_OVERRIDE=", $1" +# Usage: $0 [gpu|ane] [extra-override] +# gpu (default) — run against the GPU path of the active backend (Metal or MLX) +# ane — run against the CoreML/ANE path of the active backend (Metal or MLX) +# Result files are suffixed (_ane) so the two runs can coexist; reference files +# under $REFERENCEDIR are backend-independent and shared. +# Optional second argument: extra config overrides appended (comma-separated) to the +# -override-config of every test entry below. E.g. ./rungpuerrortest.sh gpu "useFP16=true, useNHWC=true" +MODE="${1:-gpu}" +case "$MODE" in + gpu) EXTRA_OVERRIDE=""; SUFFIX="" ;; + ane) EXTRA_OVERRIDE=", deviceToUseThread0=100"; SUFFIX="_ane" ;; + *) echo "Usage: $0 [gpu|ane] [extra-override]" >&2; exit 1 ;; +esac +if [ -n "${2:-}" ]; then + EXTRA_OVERRIDE="${EXTRA_OVERRIDE}, $2" fi REFERENCEDIR="tests/results/gpu_error_reference_files" @@ -48,130 +58,130 @@ MODELBASE12=$(basename "$MODEL12") ./katago testgpuerror -model "$MODEL1" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE1"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE1"_size9.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE1"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE1"_size9${SUFFIX}.txt ./katago testgpuerror -model "$MODEL1" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=False, maxBatchSize=16${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE1"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE1"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE1"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE1"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL2" -config configs/gtp_example.cfg -boardsize 13 \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE2"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE2"_size13.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE2"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE2"_size13${SUFFIX}.txt ./katago testgpuerror -model "$MODEL2" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True, maxBatchSize=19${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE2"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE2"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE2"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE2"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL3" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=False, maxBatchSize=32${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE3"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE3"_size9.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE3"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE3"_size9${SUFFIX}.txt ./katago testgpuerror -model "$MODEL3" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=False, maxBatchSize=2${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE3"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE3"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE3"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE3"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=True, maxBatchSize=3${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size9.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size9${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize 13 \ -override-config "requireMaxBoardSize=False, maxBatchSize=27${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size13.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size13${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize 10x14 \ -override-config "requireMaxBoardSize=True, maxBatchSize=13${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size10x14.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size10x14${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE4"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE4"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL4" -config configs/gtp_example.cfg -boardsize 13 \ -override-config "requireMaxBoardSize=False,maxBoardXSizeForNNBuffer=18,maxBoardYSizeForNNBuffer=19${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size13_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size13_rectbuffer.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE4"_size13_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE4"_size13_rectbuffer${SUFFIX}.txt ./katago testgpuerror -model "$MODEL5" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE5"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE5"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE5"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE5"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL5" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True, maxBatchSize=16${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE5"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE5"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE5"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE5"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL6" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size9.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size9${SUFFIX}.txt ./katago testgpuerror -model "$MODEL6" -config configs/gtp_example.cfg -boardsize 13 \ -override-config "requireMaxBoardSize=False, maxBatchSize=28${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size13.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size13${SUFFIX}.txt ./katago testgpuerror -model "$MODEL6" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=False, maxBatchSize=8${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL6" -config configs/gtp_example.cfg -boardsize 10x14 \ -override-config "requireMaxBoardSize=False, maxBatchSize=15${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size10x14.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE6"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE6"_size10x14${SUFFIX}.txt ./katago testgpuerror -model "$MODEL6" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE6"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE6"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE6"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE6"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=False, maxBatchSize=4${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size9.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size9.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size9${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize 13 \ -override-config "requireMaxBoardSize=True, maxBatchSize=29${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size13.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size13.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size13${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize 10x14 \ -override-config "requireMaxBoardSize=True, maxBatchSize=5${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size10x14.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size10x14.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size10x14${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE7"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE7"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=False,maxBoardXSizeForNNBuffer=16,maxBoardYSizeForNNBuffer=11, maxBatchSize=9${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size9_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size9_rectbuffer.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_size9_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE7"_size9_rectbuffer${SUFFIX}.txt ./katago testgpuerror -model "$MODEL8" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False, maxBatchSize=11${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE8"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE8"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE8"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE8"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL8" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE8"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE8"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE8"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE8"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL7" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False,policyOptimism=0.65,playoutDoublingAdvantage=0.3,nnPolicyTemperature=1.1${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE7"_sizerect_weirdsettings.txt | tee "$RESULTSDIR"/"$MODELBASE7"_sizerect_weirdsettings.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE7"_sizerect_weirdsettings.txt | tee "$RESULTSDIR"/"$MODELBASE7"_sizerect_weirdsettings${SUFFIX}.txt ./katago testgpuerror -model "$MODEL9" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False, maxBatchSize=11${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE9"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE9"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE9"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE9"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL9" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE9"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE9"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE9"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE9"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL10" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False, maxBatchSize=11${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE10"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE10"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE10"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE10"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL10" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE10"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE10"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE10"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE10"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL11" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False, maxBatchSize=11${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE11"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE11"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE11"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE11"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL11" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE11"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE11"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE11"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE11"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL12" -config configs/gtp_example.cfg -boardsize rectangle \ -override-config "requireMaxBoardSize=False, maxBatchSize=12${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE12"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE12"_sizerect.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE12"_sizerect.txt | tee "$RESULTSDIR"/"$MODELBASE12"_sizerect${SUFFIX}.txt ./katago testgpuerror -model "$MODEL12" -config configs/gtp_example.cfg -boardsize 19 \ -override-config "requireMaxBoardSize=True${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE12"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE12"_size19.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE12"_size19.txt | tee "$RESULTSDIR"/"$MODELBASE12"_size19${SUFFIX}.txt ./katago testgpuerror -model "$MODEL11" -config configs/gtp_example.cfg -boardsize 9 \ -override-config "requireMaxBoardSize=False,maxBoardXSizeForNNBuffer=16,maxBoardYSizeForNNBuffer=11,maxBatchSize=15,policyOptimism=0.70${EXTRA_OVERRIDE}" \ - -reference-file "$REFERENCEDIR"/"$MODELBASE11"_size9_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE11"_size9_rectbuffer.txt + -reference-file "$REFERENCEDIR"/"$MODELBASE11"_size9_rectbuffer.txt | tee "$RESULTSDIR"/"$MODELBASE11"_size9_rectbuffer${SUFFIX}.txt