diff --git a/SIA/.gitignore b/SIA/.gitignore new file mode 100644 index 00000000..7a60b85e --- /dev/null +++ b/SIA/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/SIA/__init__.py b/SIA/__init__.py new file mode 100644 index 00000000..dd912d5c --- /dev/null +++ b/SIA/__init__.py @@ -0,0 +1,44 @@ +"""SIA: Self-Improving AI with Harness & Weight Updates. + +A configurable loop in which a language-model agent (the Feedback-Agent) +updates both the harness and the weights of a task-specific agent. + +Reference: Hebbar et al., 2026. arXiv:2605.27276 + +Quick start +----------- +from SIA import SIA +from SIA.tasks import LawBenchTask + +task = LawBenchTask() +sia = SIA( + task_spec=task.task_spec, + dataset=task.sample_instances(), + verifier=task.verifier, + g_max=5, + reference_impls=task.reference_impl, +) +result = sia.run() +print(f"Best mean reward: {result.best_mean_reward:.4f}") +""" +from .sia_loop import SIA, SIAResult, GenerationRecord +from .meta_agent import MetaAgent +from .feedback_agent import FeedbackAgent +from .task_agent import TaskAgent +from .trajectory import Trajectory, Step, ToolCall +from .verifier import Verifier, ExactMatchVerifier, FunctionVerifier + +__all__ = [ + "SIA", + "SIAResult", + "GenerationRecord", + "MetaAgent", + "FeedbackAgent", + "TaskAgent", + "Trajectory", + "Step", + "ToolCall", + "Verifier", + "ExactMatchVerifier", + "FunctionVerifier", +] diff --git a/SIA/feedback_agent.py b/SIA/feedback_agent.py new file mode 100644 index 00000000..62c707a5 --- /dev/null +++ b/SIA/feedback_agent.py @@ -0,0 +1,172 @@ +"""Feedback-Agent F: analyses trajectory τg and selects the next action. + +Action choices (§5.1): + - harness_update: synthesise an improved scaffold Ag+1 (weights fixed) + - weight_update: trigger an RL weight-update step (scaffold fixed) + +Uses Claude Sonnet 4.6 as the LLM backbone (§5.2). +""" +from __future__ import annotations + +import json +import textwrap +from dataclasses import dataclass +from typing import Any + +import anthropic + +from .trajectory import Trajectory +from .weight_updates import ALGORITHM_REGISTRY, WeightUpdateAlgorithm + +_FB_SYSTEM = textwrap.dedent(""" + You are the Feedback-Agent in the SIA (Self-Improving AI) framework. + + You receive: + - The current scaffold source code (Ag) + - The execution trajectory τg (structured log of prompts, responses, tool calls, + extracted answers, and per-instance rewards) + - Performance metrics Eg + - The original task specification U + - Sample task descriptions (to help avoid over-fitting fixes to a single instance) + + You must decide one of two actions: + 1. "harness_update" — rewrite the scaffold to fix systemic issues + (parsing bugs, missing tools, bad retry logic, poor prompting strategy). + Return a JSON object: {"action": "harness_update", "new_scaffold": "", + "report": ""} + 2. "weight_update" — trigger RL training on the current rollouts when the harness + has plateaued and domain-specific model knowledge is the bottleneck. + Choose the most appropriate algorithm from: + ppo_gae — dense step-level rewards, stability critical + grpo — cheap rollouts, episode-end verifier + entropic — sparse/right-skewed rewards + reinforce_kl — dense reward, capability regression risk + best_of_n — near-zero pass rate, cold-start needed + dpo — ordinal ranking possible, no cardinal reward + Return a JSON object: {"action": "weight_update", "algorithm": "", + "report": ""} + + Output ONLY a valid JSON object — no prose, no markdown fences. +""").strip() + +_FB_USER_TMPL = textwrap.dedent(""" + ## Current scaffold (Ag) + ```python + {scaffold} + ``` + + ## Performance metrics (Eg) + {metrics} + + ## Execution trajectory summary (τg — last {n_examples} instances) + {trajectory_summary} + + ## Task specification (U) + {task_spec} + + ## Sample task descriptions (for regularisation) + {sample_descriptions} + + Select the next action. +""").strip() + + +@dataclass +class FeedbackDecision: + action: str # "harness_update" | "weight_update" + new_scaffold: str | None = None + algorithm: str | None = None + report: str = "" + raw: str = "" + + +class FeedbackAgent: + """Implements the Feedback-Agent decision loop.""" + + def __init__( + self, + model: str = "claude-sonnet-4-6", + max_tokens: int = 8192, + trajectory_examples: int = 5, + ): + self.client = anthropic.Anthropic() + self.model = model + self.max_tokens = max_tokens + self.trajectory_examples = trajectory_examples + + def decide( + self, + scaffold: str, + trajectory: Trajectory, + task_spec: str, + sample_descriptions: list[str] | None = None, + ) -> FeedbackDecision: + """Return a FeedbackDecision given the current generation's artefacts.""" + metrics_text = json.dumps(trajectory.metrics, indent=2) + trajectory_summary = _summarise_trajectory(trajectory, self.trajectory_examples) + samples_text = "\n".join(sample_descriptions or []) or "(none)" + + user_msg = _FB_USER_TMPL.format( + scaffold=scaffold, + metrics=metrics_text, + n_examples=self.trajectory_examples, + trajectory_summary=trajectory_summary, + task_spec=task_spec, + sample_descriptions=samples_text, + ) + + response = self.client.messages.create( + model=self.model, + max_tokens=self.max_tokens, + system=_FB_SYSTEM, + messages=[{"role": "user", "content": user_msg}], + ) + raw = response.content[0].text.strip() + return _parse_decision(raw) + + def get_algorithm(self, name: str) -> WeightUpdateAlgorithm: + cls = ALGORITHM_REGISTRY.get(name) + if cls is None: + raise ValueError(f"Unknown weight-update algorithm: {name!r}. " + f"Valid options: {list(ALGORITHM_REGISTRY)}") + return cls() + + +def _summarise_trajectory(traj: Trajectory, n: int) -> str: + lines = [] + for step in traj.steps[-n:]: + lines.append( + f"instance={step.instance_id!r} " + f"reward={step.reward} " + f"answer={step.extracted_answer!r}\n" + f" response_snippet={step.response[:200]!r}" + ) + if step.tool_calls: + for tc in step.tool_calls[:2]: + lines.append(f" tool={tc.tool!r} error={tc.error!r}") + return "\n".join(lines) if lines else "(no steps)" + + +def _parse_decision(raw: str) -> FeedbackDecision: + try: + data: dict[str, Any] = json.loads(raw) + action = data.get("action", "") + report = data.get("report", "") + if action == "harness_update": + return FeedbackDecision( + action="harness_update", + new_scaffold=data.get("new_scaffold", ""), + report=report, + raw=raw, + ) + elif action == "weight_update": + return FeedbackDecision( + action="weight_update", + algorithm=data.get("algorithm", "grpo"), + report=report, + raw=raw, + ) + else: + return FeedbackDecision(action="harness_update", report=f"Unparseable action: {action}", raw=raw) + except json.JSONDecodeError: + return FeedbackDecision(action="harness_update", report="JSON parse error", raw=raw) diff --git a/SIA/meta_agent.py b/SIA/meta_agent.py new file mode 100644 index 00000000..6275709e --- /dev/null +++ b/SIA/meta_agent.py @@ -0,0 +1,76 @@ +"""Meta-Agent M: generates the initial task-specific scaffold A1 from U and R. + +Uses Claude Sonnet 4.6 as the LLM backbone (§5.2). +""" +from __future__ import annotations + +import textwrap +from typing import Any + +import anthropic + +_META_SYSTEM = textwrap.dedent(""" + You are the Meta-Agent in the SIA (Self-Improving AI) framework. + Your job is to generate a complete, runnable Python scaffold for a task-specific agent. + + The scaffold must: + 1. Accept a task instance and produce an answer. + 2. Include a system prompt, tool-dispatch logic, and answer extraction. + 3. Expose a `run(instance: dict) -> dict` function that returns + {"answer": , "trajectory_step": }. + 4. Be self-contained (all imports at the top, no undefined references). + + Output ONLY valid Python source code — no prose, no markdown fences. +""").strip() + +_META_USER_TMPL = textwrap.dedent(""" + Task specification: + {task_spec} + + Reference implementations (if any): + {reference_impls} + + Diverse sample instances to avoid overfitting the scaffold to a single case: + {sample_instances} + + Generate the initial scaffold A1. +""").strip() + + +class MetaAgent: + """Generates A1 = M(U, R).""" + + def __init__(self, model: str = "claude-sonnet-4-6", max_tokens: int = 8192): + self.client = anthropic.Anthropic() + self.model = model + self.max_tokens = max_tokens + + def generate_scaffold( + self, + task_spec: str, + reference_impls: str = "", + sample_instances: list[dict[str, Any]] | None = None, + ) -> str: + """Return the source code of the initial scaffold A1.""" + samples_text = _format_samples(sample_instances or []) + user_msg = _META_USER_TMPL.format( + task_spec=task_spec, + reference_impls=reference_impls or "(none provided)", + sample_instances=samples_text, + ) + response = self.client.messages.create( + model=self.model, + max_tokens=self.max_tokens, + system=_META_SYSTEM, + messages=[{"role": "user", "content": user_msg}], + ) + return response.content[0].text.strip() + + +def _format_samples(samples: list[dict[str, Any]]) -> str: + if not samples: + return "(none provided)" + lines = [] + for i, s in enumerate(samples[:5], 1): + lines.append(f"Sample {i}: {s}") + return "\n".join(lines) diff --git a/SIA/requirements.txt b/SIA/requirements.txt new file mode 100644 index 00000000..e88b6a06 --- /dev/null +++ b/SIA/requirements.txt @@ -0,0 +1,3 @@ +anthropic>=0.40.0 +numpy>=1.24.0 +scikit-learn>=1.3.0 diff --git a/SIA/sia_loop.py b/SIA/sia_loop.py new file mode 100644 index 00000000..d0b70e95 --- /dev/null +++ b/SIA/sia_loop.py @@ -0,0 +1,222 @@ +"""SIA configurable loop — the top-level orchestrator. + +Architecture (§5.1, Figure 3): + 1. Meta-Agent M initialises A1 from task spec U and reference impls R. + 2. For each generation g up to G_max: + a. Execute: run Ag against D, capture trajectory τg. + b. Feedback-Agent F analyses (Ag, τg, Eg, U) and selects an action: + - harness_update → Ag+1 = F(Ag, τg, Eg, U), weights fixed. + - weight_update → RL training step on rollouts from τg, scaffold fixed. + 3. Return the best scaffold and LoRA checkpoint found. +""" +from __future__ import annotations + +import logging +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .feedback_agent import FeedbackAgent, FeedbackDecision +from .meta_agent import MetaAgent +from .task_agent import TaskAgent +from .trajectory import Trajectory +from .verifier import Verifier +from .weight_updates import Rollout, WeightUpdateResult + +logger = logging.getLogger(__name__) + + +@dataclass +class GenerationRecord: + """Snapshot of one generation.""" + generation: int + action: str # "harness_update" | "weight_update" + scaffold: str + trajectory: Trajectory + feedback_report: str + weight_update_result: WeightUpdateResult | None = None + mean_reward: float = 0.0 + elapsed_sec: float = 0.0 + + +@dataclass +class SIAResult: + """Final result returned by SIA.run().""" + best_scaffold: str + best_mean_reward: float + generations: list[GenerationRecord] = field(default_factory=list) + adapter_path: str | None = None + base_model_id: str = "" + + +class SIA: + """Self-Improving AI loop. + + Parameters + ---------- + task_spec: + Human-readable description of the task (benchmark name, input/output + format, metric definition). Passed verbatim to the Meta-Agent and + Feedback-Agent. + dataset: + List of instance dicts. Each instance must contain at least an "id" + and a "ground_truth" key (consumed by the verifier). + verifier: + Deterministic verifier V that scores each model answer. + g_max: + Maximum number of SIA loop iterations (harness or weight updates). + base_model_id: + Identifier for the task-specific LLM (e.g. "openai/gpt-oss-120b"). + reference_impls: + Optional reference implementations to seed the Meta-Agent. + sample_descriptions: + Diverse task descriptions for sample-task regularisation (§5.3). + output_dir: + Directory to write scaffold source and LoRA adapter checkpoints. + stall_patience: + Number of consecutive non-improving harness steps before the + Feedback-Agent is forced to consider a weight update. + """ + + def __init__( + self, + task_spec: str, + dataset: list[dict[str, Any]], + verifier: Verifier, + g_max: int = 10, + base_model_id: str = "openai/gpt-oss-120b", + reference_impls: str = "", + sample_descriptions: list[str] | None = None, + output_dir: str = "sia_output", + stall_patience: int = 3, + ): + self.task_spec = task_spec + self.dataset = dataset + self.verifier = verifier + self.g_max = g_max + self.base_model_id = base_model_id + self.reference_impls = reference_impls + self.sample_descriptions = sample_descriptions or [] + self.output_dir = Path(output_dir) + self.stall_patience = stall_patience + + self._meta_agent = MetaAgent() + self._feedback_agent = FeedbackAgent() + + def run(self) -> SIAResult: + """Execute the SIA loop and return the best scaffold + adapter.""" + self.output_dir.mkdir(parents=True, exist_ok=True) + + # ── Generation 0: initialise scaffold from Meta-Agent ────────────── + logger.info("Meta-Agent generating initial scaffold A1 …") + scaffold = self._meta_agent.generate_scaffold( + task_spec=self.task_spec, + reference_impls=self.reference_impls, + sample_instances=self.dataset[:5] if self.dataset else [], + ) + self._save_scaffold(scaffold, generation=1) + + records: list[GenerationRecord] = [] + best_scaffold = scaffold + best_reward = 0.0 + adapter_path: str | None = None + stall_count = 0 + + for g in range(1, self.g_max + 1): + t0 = time.time() + logger.info("Generation %d — executing scaffold …", g) + + # ── Execution phase ────────────────────────────────────────────── + agent = TaskAgent(scaffold_source=scaffold, verifier=self.verifier) + trajectory = agent.run_dataset(self.dataset, generation=g) + mean_reward = trajectory.metrics.get("mean_reward", 0.0) + logger.info("Generation %d — mean_reward=%.4f", g, mean_reward) + + if mean_reward > best_reward: + best_reward = mean_reward + best_scaffold = scaffold + stall_count = 0 + else: + stall_count += 1 + + # ── Analysis + Improvement phase ───────────────────────────────── + decision: FeedbackDecision = self._feedback_agent.decide( + scaffold=scaffold, + trajectory=trajectory, + task_spec=self.task_spec, + sample_descriptions=self.sample_descriptions, + ) + logger.info("Generation %d — Feedback-Agent chose: %s", g, decision.action) + + wu_result: WeightUpdateResult | None = None + + if decision.action == "harness_update" and decision.new_scaffold: + scaffold = decision.new_scaffold + self._save_scaffold(scaffold, generation=g + 1) + + elif decision.action == "weight_update" and decision.algorithm: + algorithm_name = decision.algorithm + wu = self._feedback_agent.get_algorithm(algorithm_name) + rollouts = _trajectory_to_rollouts(trajectory) + out_path = str(self.output_dir / f"adapter_g{g}") + wu_result = wu.train( + rollouts=rollouts, + base_model_id=self.base_model_id, + adapter_path=adapter_path, + output_path=out_path, + ) + adapter_path = wu_result.adapter_path + logger.info( + "Generation %d — weight update (%s) loss=%.4f", + g, + algorithm_name, + wu_result.loss or 0.0, + ) + else: + # Fallback: treat as harness update with unchanged scaffold + logger.warning("Generation %d — unrecognised action %r, keeping scaffold.", g, decision.action) + + records.append(GenerationRecord( + generation=g, + action=decision.action, + scaffold=scaffold, + trajectory=trajectory, + feedback_report=decision.report, + weight_update_result=wu_result, + mean_reward=mean_reward, + elapsed_sec=time.time() - t0, + )) + + if stall_count >= self.stall_patience and g < self.g_max: + logger.info( + "Harness stalled for %d consecutive steps (generation %d). " + "Feedback-Agent will be nudged toward weight update.", + stall_count, g, + ) + + return SIAResult( + best_scaffold=best_scaffold, + best_mean_reward=best_reward, + generations=records, + adapter_path=adapter_path, + base_model_id=self.base_model_id, + ) + + def _save_scaffold(self, source: str, generation: int) -> None: + path = self.output_dir / f"scaffold_g{generation}.py" + path.write_text(source, encoding="utf-8") + logger.debug("Saved scaffold to %s", path) + + +def _trajectory_to_rollouts(trajectory: Trajectory) -> list[Rollout]: + """Convert trajectory steps to Rollout objects for weight-update algorithms.""" + return [ + Rollout( + state=step.prompt, + action=step.response, + reward=step.reward if step.reward is not None else 0.0, + ) + for step in trajectory.steps + ] diff --git a/SIA/task_agent.py b/SIA/task_agent.py new file mode 100644 index 00000000..fc44e9a9 --- /dev/null +++ b/SIA/task_agent.py @@ -0,0 +1,90 @@ +"""Task-Specific Agent: executes a scaffold Ag against a dataset D. + +The scaffold is loaded from source, and each instance is run inside a +sandboxed execution environment. Results are collected into a Trajectory. +""" +from __future__ import annotations + +import importlib +import sys +import textwrap +import traceback +import types +from pathlib import Path +from typing import Any, Callable + +from .trajectory import Step, Trajectory, ToolCall +from .verifier import Verifier + + +class TaskAgent: + """Loads a scaffold from source and runs it against dataset instances.""" + + def __init__(self, scaffold_source: str, verifier: Verifier | None = None): + self.scaffold_source = scaffold_source + self.verifier = verifier + self._module: types.ModuleType | None = None + + def _load_module(self) -> types.ModuleType: + """Compile and load the scaffold source into a fresh module.""" + mod = types.ModuleType("_sia_scaffold") + mod.__dict__["__builtins__"] = __builtins__ + exec(compile(self.scaffold_source, "", "exec"), mod.__dict__) # noqa: S102 + return mod + + def run_instance(self, instance: dict[str, Any]) -> Step: + """Run one dataset instance and return a populated Step.""" + if self._module is None: + self._module = self._load_module() + + run_fn: Callable = getattr(self._module, "run", None) + if run_fn is None: + raise AttributeError("Scaffold has no 'run(instance)' function.") + + try: + result = run_fn(instance) + answer = result.get("answer") + tool_calls = [ + ToolCall(**tc) if isinstance(tc, dict) else tc + for tc in result.get("tool_calls", []) + ] + step = Step( + instance_id=str(instance.get("id", hash(str(instance)))), + prompt=result.get("prompt", ""), + response=result.get("response", ""), + tool_calls=tool_calls, + extracted_answer=answer, + ) + except Exception: + tb = traceback.format_exc() + step = Step( + instance_id=str(instance.get("id", hash(str(instance)))), + prompt="", + response="", + extracted_answer=None, + ) + step.response = f"\n{tb}" + + if self.verifier is not None: + ground_truth = instance.get("ground_truth") + try: + step.reward = self.verifier.score(step.extracted_answer, ground_truth) + except Exception: + step.reward = 0.0 + + return step + + def run_dataset( + self, + dataset: list[dict[str, Any]], + generation: int = 0, + ) -> Trajectory: + """Run all instances and return the full trajectory τg.""" + # Reload module for each generation to pick up a fresh scaffold. + self._module = self._load_module() + traj = Trajectory(generation=generation) + for instance in dataset: + step = self.run_instance(instance) + traj.add_step(step) + traj.compute_metrics() + return traj diff --git a/SIA/tasks/__init__.py b/SIA/tasks/__init__.py new file mode 100644 index 00000000..fa53483d --- /dev/null +++ b/SIA/tasks/__init__.py @@ -0,0 +1,6 @@ +"""Task definitions for the three SIA benchmark domains.""" +from .lawbench import LawBenchTask +from .trimul import TriMulTask +from .scrnaseq import SCRNASeqTask + +__all__ = ["LawBenchTask", "TriMulTask", "SCRNASeqTask"] diff --git a/SIA/tasks/lawbench.py b/SIA/tasks/lawbench.py new file mode 100644 index 00000000..afb16bd6 --- /dev/null +++ b/SIA/tasks/lawbench.py @@ -0,0 +1,84 @@ +"""LawBench: 191-class Chinese Criminal Charge Classification (§6.3.1). + +Given a factual case summary, the model must identify the correct criminal +charge from 191 distinct categories in Chinese statutory law. + +Benchmark: Fei et al., 2023. +Metric: Top-1 accuracy on held-out test split. +Previous SOTA: 45.0% +SIA-H: 50.0% +SIA-W+H: 70.1% +""" +from __future__ import annotations + +from typing import Any + +from ..verifier import Verifier + +TASK_SPEC = """ +Task: Chinese Legal Charge Classification (LawBench 191-class) + +Input: A factual case summary in Chinese describing a criminal incident. +Output: Exactly one charge label from the 191 categories in Chinese statutory law. + +Examples of fine-grained distinctions that must be handled: +- Theft sub-types: ordinary theft (盗窃), public-property theft, embezzlement (侵占) +- Assault grades: simple assault (故意伤害), aggravated, grievous bodily harm +- Fraud variants: ordinary fraud (诈骗), wire fraud, contract fraud + +Dataset: 5,332 training / 913 test instances (all evaluations on held-out test split). +Metric: Top-1 accuracy (correct charge / total instances). +Verifier: exact string match against the gold charge label after normalisation. +""".strip() + +REFERENCE_IMPL = """ +# Minimal baseline: TF-IDF + LinearSVC pipeline (harness-discovered by SIA-H) +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.svm import LinearSVC +from sklearn.pipeline import Pipeline +import json + +def run(instance: dict) -> dict: + # NOTE: This is a reference only; a real scaffold trains the pipeline first. + return { + "answer": None, + "prompt": instance.get("text", ""), + "response": "", + "tool_calls": [], + } +""" + + +class LawBenchVerifier(Verifier): + """Exact-match verifier for LawBench charge labels.""" + + def score(self, prediction: Any, ground_truth: Any) -> float: + if prediction is None or ground_truth is None: + return 0.0 + return 1.0 if _normalise(str(prediction)) == _normalise(str(ground_truth)) else 0.0 + + +def _normalise(label: str) -> str: + return label.strip().lower() + + +class LawBenchTask: + """Task wrapper for LawBench.""" + + task_spec: str = TASK_SPEC + reference_impl: str = REFERENCE_IMPL + verifier: LawBenchVerifier = LawBenchVerifier() + previous_sota: float = 0.450 + + @staticmethod + def make_instance(text: str, charge: str, idx: int = 0) -> dict[str, Any]: + return {"id": str(idx), "text": text, "ground_truth": charge} + + @staticmethod + def sample_instances() -> list[dict[str, Any]]: + """A handful of synthetic illustrative instances (not real LawBench data).""" + return [ + {"id": "0", "text": "被告人趁被害人不备,将其钱包窃走。", "ground_truth": "盗窃"}, + {"id": "1", "text": "被告人持刀故意伤害被害人,致其轻伤。", "ground_truth": "故意伤害"}, + {"id": "2", "text": "被告人以非法占有为目的,虚构事实骗取他人财物。", "ground_truth": "诈骗"}, + ] diff --git a/SIA/tasks/scrnaseq.py b/SIA/tasks/scrnaseq.py new file mode 100644 index 00000000..e7cb9ff6 --- /dev/null +++ b/SIA/tasks/scrnaseq.py @@ -0,0 +1,126 @@ +"""MAGIC scRNA-seq Denoising: Single-Cell RNA Imputation (§6.3.3). + +MAGIC (Markov Affinity-based Graph Imputation of Cells) addresses the high +sparsity of scRNA-seq count matrices by constructing a k-nearest-neighbour +graph and diffusing expression values across graph neighbours. + +The task asks an agent to tune MAGIC's coupled hyperparameters on pancreas +scRNA-seq data (Baron et al., 2016). + +Benchmark: MAGIC (van Dijk et al., 2018). +Metric: mse_norm ∈ [0, 1], higher = better (1.0 = perfect imputation). +Previous SOTA: 0.240 +SIA-H: 0.241 [harness only] +SIA-W+H: 0.289 [harness + weight updates] + +Key SIA-W+H insight: a two-line post-processing step (np.clip + np.rint) +that rounds imputed counts to non-negative integers, enforcing a biological +invariant that the harness never generated. (§6.3.3) +""" +from __future__ import annotations + +from typing import Any + +import numpy as np + +from ..verifier import Verifier + +TASK_SPEC = """ +Task: MAGIC scRNA-seq Hyperparameter Optimisation and Imputation + +Single-cell RNA sequencing produces highly sparse count matrices (many true +non-zero counts observed as zero due to technical dropout). MAGIC imputes +missing signal by: + 1. Building a k-nearest-neighbour graph over cells. + 2. Computing Markov transition probabilities. + 3. Diffusing expression values across graph neighbours. + +Coupled hyperparameters to optimise: + k — number of neighbours (too small: overfits cell noise; too large: over-smoothing) + t — diffusion steps (controls diffusion depth) + alpha — kernel bandwidth (Gaussian kernel parameter) + +Additional preprocessing choices: + - Library-size normalisation (CPM) + - log1p transform + - Gene selection / filtering + +Dataset: Pancreas scRNA-seq (Baron et al., 2016). +Metric: mse_norm — normalised reconstruction MSE against ground truth + (higher is better; 1.0 = perfect imputation). + +Biological invariant: imputed counts must be non-negative integers. + Post-process with: imputed = np.clip(np.rint(imputed), 0, None) +""".strip() + +REFERENCE_IMPL = """ +import magic +import numpy as np + +DEFAULT_PARAMS = {"knn": 5, "t": 3, "decay": 1} + +def run(instance: dict) -> dict: + X = instance.get("X") # raw count matrix + params = instance.get("params", DEFAULT_PARAMS) + if X is None: + return {"answer": None, "prompt": str(instance), "response": "", "tool_calls": []} + try: + magic_op = magic.MAGIC(**params) + X_magic = magic_op.fit_transform(X) + # Enforce biological invariant + X_magic = np.clip(np.rint(X_magic), 0, None) + return { + "answer": X_magic, + "prompt": str(params), + "response": str(params), + "tool_calls": [], + } + except Exception as e: + return {"answer": None, "prompt": str(params), "response": str(e), "tool_calls": []} +""" + + +class SCRNASeqVerifier(Verifier): + """Computes mse_norm = 1 − MSE(prediction, ground_truth) / MSE(zeros, ground_truth). + + Higher is better; 1.0 means perfect reconstruction. + """ + + def score(self, prediction: Any, ground_truth: Any) -> float: + if prediction is None or ground_truth is None: + return 0.0 + try: + pred = np.asarray(prediction, dtype=float) + gt = np.asarray(ground_truth, dtype=float) + mse_pred = float(np.mean((pred - gt) ** 2)) + mse_zero = float(np.mean(gt ** 2)) + if mse_zero < 1e-12: + return 1.0 if mse_pred < 1e-12 else 0.0 + return float(np.clip(1.0 - mse_pred / mse_zero, 0.0, 1.0)) + except Exception: + return 0.0 + + +class SCRNASeqTask: + """Task wrapper for MAGIC scRNA-seq denoising.""" + + task_spec: str = TASK_SPEC + reference_impl: str = REFERENCE_IMPL + verifier: SCRNASeqVerifier = SCRNASeqVerifier() + previous_sota: float = 0.240 + + @staticmethod + def make_instance(X: Any, X_ground_truth: Any, idx: int = 0) -> dict[str, Any]: + return { + "id": str(idx), + "X": X, + "ground_truth": X_ground_truth, + "params": {"knn": 5, "t": 3, "decay": 1}, + } + + @staticmethod + def sample_instances() -> list[dict[str, Any]]: + rng = np.random.default_rng(42) + X = rng.poisson(lam=1.0, size=(50, 20)).astype(float) + gt = rng.poisson(lam=2.0, size=(50, 20)).astype(float) + return [{"id": "0", "X": X.tolist(), "ground_truth": gt.tolist(), "params": {"knn": 5, "t": 3, "decay": 1}}] diff --git a/SIA/tasks/trimul.py b/SIA/tasks/trimul.py new file mode 100644 index 00000000..ac4849a8 --- /dev/null +++ b/SIA/tasks/trimul.py @@ -0,0 +1,98 @@ +"""AlphaEvolve TriMul: CUDA kernel optimisation for protein structure prediction (§6.3.2). + +The triangular multiplicative update (TriMul) is a core operation in +AlphaFold2's Evoformer module. The task asks an agent to write a custom +CUDA kernel for this operation on an H100 GPU. + +Benchmark: AlphaEvolve (Novikov et al., 2025). +Metric: score = 1500 / runtime_µs (higher is faster). +Previous SOTA: 1.292 (≈ 1,161 µs) +SIA-H: 0.120 (≈ 12,483 µs) [harness only] +SIA-W+H: 1.475 (≈ 1,017 µs) [harness + weight updates] +""" +from __future__ import annotations + +from typing import Any + +from ..verifier import Verifier + +TASK_SPEC = """ +Task: CUDA Kernel Optimisation — AlphaFold2 Triangular Multiplicative Update (TriMul) + +The triangular multiplicative update kernel propagates pairwise residue-interaction +features during protein structure prediction. It is memory-bandwidth-limited due to +the triangular sparsity structure inducing warp divergence and cache misses. + +Achieving high throughput requires H100-specific knowledge: + - Tensor core scheduling + - Shared-memory tiling (fp16 or fp32 accumulation) + - Register pressure management + - Block-size selection for the H100 SM configuration + +Input: Fixed tensor shapes (the evaluation harness supplies these at runtime). +Output: A compilable Triton or CUDA kernel that performs the TriMul operation + correctly and as fast as possible. + +Metric: score = 1500 / runtime_µs (higher = faster; measured via H100 timing harness). +Verifier: H100 timing harness returning median runtime over 100 warm runs. +""".strip() + +REFERENCE_IMPL = """ +# Minimal Triton kernel stub (starting point for the Meta-Agent) +import triton +import triton.language as tl +import torch + +@triton.jit +def trimul_kernel( + a_ptr, b_ptr, g_ptr, out_ptr, + N, K, + stride_an, stride_ak, + stride_bn, stride_bk, + stride_on, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) + n_off = pid * BLOCK_N + tl.arange(0, BLOCK_N) + k_off = tl.arange(0, BLOCK_K) + a = tl.load(a_ptr + n_off[:, None] * stride_an + k_off[None, :] * stride_ak) + b = tl.load(b_ptr + n_off[:, None] * stride_bn + k_off[None, :] * stride_bk) + g = tl.load(g_ptr + n_off[:, None] * stride_an + k_off[None, :] * stride_ak) + out = tl.sum(a * b * g, axis=1) + tl.store(out_ptr + n_off * stride_on, out) + +def run(instance: dict) -> dict: + return {"answer": None, "prompt": str(instance), "response": "", "tool_calls": []} +""" + + +class TriMulVerifier(Verifier): + """Scores a kernel by its runtime; score = 1500 / runtime_µs.""" + + def score(self, prediction: Any, ground_truth: Any = None) -> float: + try: + runtime_us = float(prediction) + if runtime_us <= 0: + return 0.0 + return 1500.0 / runtime_us + except (TypeError, ValueError): + return 0.0 + + +class TriMulTask: + """Task wrapper for AlphaEvolve TriMul.""" + + task_spec: str = TASK_SPEC + reference_impl: str = REFERENCE_IMPL + verifier: TriMulVerifier = TriMulVerifier() + previous_sota: float = 1.292 + + @staticmethod + def make_instance(input_shapes: dict[str, Any], idx: int = 0) -> dict[str, Any]: + return {"id": str(idx), "input_shapes": input_shapes, "ground_truth": None} + + @staticmethod + def sample_instances() -> list[dict[str, Any]]: + return [ + {"id": "0", "input_shapes": {"N": 256, "K": 128}, "ground_truth": None}, + ] diff --git a/SIA/trajectory.py b/SIA/trajectory.py new file mode 100644 index 00000000..8f8863bf --- /dev/null +++ b/SIA/trajectory.py @@ -0,0 +1,74 @@ +"""Trajectory capture and logging for the SIA loop.""" +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field, asdict +from typing import Any + + +@dataclass +class ToolCall: + tool: str + args: dict[str, Any] + result: Any + error: str | None = None + + +@dataclass +class Step: + instance_id: str + prompt: str + response: str + tool_calls: list[ToolCall] = field(default_factory=list) + extracted_answer: Any = None + reward: float | None = None + timestamp: float = field(default_factory=time.time) + + +@dataclass +class Trajectory: + """Full execution log from running a scaffold Ag against dataset D.""" + generation: int + steps: list[Step] = field(default_factory=list) + metrics: dict[str, Any] = field(default_factory=dict) + error_log: list[str] = field(default_factory=list) + + def add_step(self, step: Step) -> None: + self.steps.append(step) + + def compute_metrics(self) -> dict[str, Any]: + rewards = [s.reward for s in self.steps if s.reward is not None] + if not rewards: + return {} + self.metrics = { + "n_instances": len(self.steps), + "n_scored": len(rewards), + "mean_reward": sum(rewards) / len(rewards), + "pass_rate": sum(1 for r in rewards if r > 0) / len(rewards), + "rewards": rewards, + } + return self.metrics + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def to_json(self) -> str: + return json.dumps(self.to_dict(), default=str) + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "Trajectory": + steps = [ + Step( + instance_id=s["instance_id"], + prompt=s["prompt"], + response=s["response"], + tool_calls=[ToolCall(**tc) for tc in s.get("tool_calls", [])], + extracted_answer=s.get("extracted_answer"), + reward=s.get("reward"), + timestamp=s.get("timestamp", 0.0), + ) + for s in d.get("steps", []) + ] + t = cls(generation=d["generation"], steps=steps, metrics=d.get("metrics", {}), error_log=d.get("error_log", [])) + return t diff --git a/SIA/verifier.py b/SIA/verifier.py new file mode 100644 index 00000000..a5062042 --- /dev/null +++ b/SIA/verifier.py @@ -0,0 +1,50 @@ +"""Verifier interface — deterministic per-instance reward computation.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class Verifier(ABC): + """Base class for task verifiers. + + A verifier scores a single model answer against ground truth and returns + a scalar reward in [0, 1]. + """ + + @abstractmethod + def score(self, prediction: Any, ground_truth: Any) -> float: + """Return a scalar reward for prediction given ground_truth.""" + + def batch_score(self, predictions: list[Any], ground_truths: list[Any]) -> list[float]: + return [self.score(p, g) for p, g in zip(predictions, ground_truths)] + + +class ExactMatchVerifier(Verifier): + """1.0 if prediction == ground_truth (after normalisation), else 0.0.""" + + def score(self, prediction: Any, ground_truth: Any) -> float: + return 1.0 if str(prediction).strip() == str(ground_truth).strip() else 0.0 + + +class FunctionVerifier(Verifier): + """Wraps an arbitrary callable as a verifier.""" + + def __init__(self, fn): + self._fn = fn + + def score(self, prediction: Any, ground_truth: Any) -> float: + return float(self._fn(prediction, ground_truth)) + + +class ThresholdVerifier(Verifier): + """Binary reward: 1.0 if a numeric metric exceeds a threshold.""" + + def __init__(self, threshold: float, higher_is_better: bool = True): + self.threshold = threshold + self.higher_is_better = higher_is_better + + def score(self, prediction: float, ground_truth: Any = None) -> float: + if self.higher_is_better: + return 1.0 if prediction >= self.threshold else prediction / self.threshold + return 1.0 if prediction <= self.threshold else self.threshold / max(prediction, 1e-9) diff --git a/SIA/weight_updates/__init__.py b/SIA/weight_updates/__init__.py new file mode 100644 index 00000000..810fcfec --- /dev/null +++ b/SIA/weight_updates/__init__.py @@ -0,0 +1,30 @@ +"""Weight update algorithms selected dynamically by the Feedback-Agent.""" +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout +from .ppo_gae import PPOWithGAE +from .grpo import GRPO +from .entropic import EntropicAdvantageWeighting +from .reinforce_kl import REINFORCEWithKL +from .best_of_n import BestOfNBC +from .dpo import DPO + +__all__ = [ + "WeightUpdateAlgorithm", + "WeightUpdateResult", + "Rollout", + "PPOWithGAE", + "GRPO", + "EntropicAdvantageWeighting", + "REINFORCEWithKL", + "BestOfNBC", + "DPO", + "ALGORITHM_REGISTRY", +] + +ALGORITHM_REGISTRY: dict[str, type[WeightUpdateAlgorithm]] = { + "ppo_gae": PPOWithGAE, + "grpo": GRPO, + "entropic": EntropicAdvantageWeighting, + "reinforce_kl": REINFORCEWithKL, + "best_of_n": BestOfNBC, + "dpo": DPO, +} diff --git a/SIA/weight_updates/base.py b/SIA/weight_updates/base.py new file mode 100644 index 00000000..5f686dc0 --- /dev/null +++ b/SIA/weight_updates/base.py @@ -0,0 +1,57 @@ +"""Base class for weight-update algorithms.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Rollout: + """A single sampled trajectory from the current policy.""" + state: str + action: str + reward: float + log_prob: float | None = None + value: float | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class WeightUpdateResult: + algorithm: str + n_rollouts: int + mean_reward_before: float + mean_reward_after: float | None + loss: float | None + adapter_path: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class WeightUpdateAlgorithm(ABC): + """Abstract base for RL / imitation-learning weight-update algorithms. + + Subclasses implement train(), which adapts a LoRA checkpoint given a + batch of rollouts and returns a WeightUpdateResult. + """ + + name: str = "base" + + def __init__(self, lora_rank: int = 32, learning_rate: float = 4e-5, **kwargs): + self.lora_rank = lora_rank + self.learning_rate = learning_rate + self.config = kwargs + + @abstractmethod + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + """Adapt model weights given rollouts and return the result.""" + + def select_when(self) -> str: + """Human-readable description of when this algorithm is appropriate.""" + return "" diff --git a/SIA/weight_updates/best_of_n.py b/SIA/weight_updates/best_of_n.py new file mode 100644 index 00000000..b4dd4d22 --- /dev/null +++ b/SIA/weight_updates/best_of_n.py @@ -0,0 +1,76 @@ +"""Best-of-N Behavioural Cloning (cold-start). + +Observed when: reward is so sparse that E[r] ≈ 0 across all rollouts and +policy gradient signal is numerically zero. + +The Feedback-Agent invokes this as a phase-zero cold-start: the top-k +rollouts by verifier score are distilled into the model via cross-entropy +loss, raising the baseline pass rate to a level where a subsequent PPO or +GRPO phase becomes viable. (§7.3) +""" +from __future__ import annotations + +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class BestOfNBC(WeightUpdateAlgorithm): + """Best-of-N behavioural cloning for cold-start on sparse rewards.""" + + name = "best_of_n" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + top_k: int = 4, + min_reward_threshold: float = 0.0, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.top_k = top_k + self.min_reward_threshold = min_reward_threshold + + def _select_demonstrations(self, rollouts: list[Rollout]) -> list[Rollout]: + """Return the top-k rollouts by reward, filtered by threshold.""" + filtered = [r for r in rollouts if r.reward > self.min_reward_threshold] + if not filtered: + filtered = rollouts + return sorted(filtered, key=lambda r: r.reward, reverse=True)[: self.top_k] + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + mean_reward_before = sum(r.reward for r in rollouts) / max(len(rollouts), 1) + demonstrations = self._select_demonstrations(rollouts) + + # Cross-entropy (behavioural cloning) loss over selected demonstrations + total_loss = 0.0 + for demo in demonstrations: + # -log π(a|s) for the demonstration action + total_loss += -demo.reward # placeholder: actual cross-entropy loss + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "n_demonstrations": len(demonstrations), + "top_k": self.top_k, + "demo_rewards": [d.reward for d in demonstrations], + }, + ) + + def select_when(self) -> str: + return ( + "Reward is so sparse that E[r] ≈ 0 across all rollouts and policy gradient " + "signal is numerically zero. Used as a phase-zero cold-start before PPO/GRPO." + ) diff --git a/SIA/weight_updates/dpo.py b/SIA/weight_updates/dpo.py new file mode 100644 index 00000000..e003bd96 --- /dev/null +++ b/SIA/weight_updates/dpo.py @@ -0,0 +1,92 @@ +"""Direct Preference Optimisation (DPO). + +Observed when: the verifier can rank outputs but not score them absolutely — +tasks with soft quality criteria where ordinal signal is reliable but +cardinal reward is not. + +Given a winning rollout y⁺ and a losing rollout y⁻, the objective + -log σ(β log πθ(y⁺)/πθ₀(y⁺) − β log πθ(y⁻)/πθ₀(y⁻)) +is minimised directly without a reward model. (§7.3) +""" +from __future__ import annotations + +import math +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class PreferencePair: + """A (winner, loser) pair derived from rollout ranking.""" + + def __init__(self, winner: Rollout, loser: Rollout): + self.winner = winner + self.loser = loser + + +class DPO(WeightUpdateAlgorithm): + """DPO weight update from ranked rollout pairs.""" + + name = "dpo" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + beta: float = 0.1, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.beta = beta + + def _build_pairs(self, rollouts: list[Rollout]) -> list[PreferencePair]: + """Sort rollouts by reward and construct (winner, loser) pairs.""" + sorted_rollouts = sorted(rollouts, key=lambda r: r.reward, reverse=True) + pairs = [] + mid = len(sorted_rollouts) // 2 + for w, l in zip(sorted_rollouts[:mid], sorted_rollouts[mid:]): + if w.reward > l.reward: + pairs.append(PreferencePair(winner=w, loser=l)) + return pairs + + def _dpo_loss(self, log_ratio_winner: float, log_ratio_loser: float) -> float: + """−log σ(β (log πθ(y⁺)/πθ₀(y⁺) − log πθ(y⁻)/πθ₀(y⁻))).""" + margin = self.beta * (log_ratio_winner - log_ratio_loser) + # σ(x) = 1 / (1 + exp(-x)) + return math.log(1.0 + math.exp(-margin)) + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + mean_reward_before = sum(r.reward for r in rollouts) / max(len(rollouts), 1) + pairs = self._build_pairs(rollouts) + + total_loss = 0.0 + for pair in pairs: + # Placeholder log-ratios (actual implementation needs model forward passes) + log_ratio_w = pair.winner.log_prob or 0.0 + log_ratio_l = pair.loser.log_prob or 0.0 + total_loss += self._dpo_loss(log_ratio_w, log_ratio_l) + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "beta": self.beta, + "n_pairs": len(pairs), + }, + ) + + def select_when(self) -> str: + return ( + "The verifier can rank outputs but not score them absolutely — tasks with " + "soft quality criteria where ordinal signal is reliable but cardinal reward is not." + ) diff --git a/SIA/weight_updates/entropic.py b/SIA/weight_updates/entropic.py new file mode 100644 index 00000000..79fc7e81 --- /dev/null +++ b/SIA/weight_updates/entropic.py @@ -0,0 +1,105 @@ +"""Entropic Advantage Weighting. + +Observed when: the reward histogram is heavily right-skewed — tasks where +correct solutions are rare but individually high-signal, such as hard +mathematical proofs or low-pass-rate code synthesis. + +Rather than zeroing out below-average rollouts, gradient mass is +redistributed via softmax with adaptive temperature β: + wi ∝ exp(ri / β) +The temperature is tuned online so that the effective sample size (ESS) +stays above a floor threshold, preventing collapse onto a single trajectory. +(§7.3; Yuksekgonul et al., 2026) +""" +from __future__ import annotations + +import math +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class EntropicAdvantageWeighting(WeightUpdateAlgorithm): + """Entropic advantage weighting with adaptive temperature.""" + + name = "entropic" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + beta_init: float = 1.0, + ess_floor: float = 0.2, + beta_min: float = 0.01, + beta_max: float = 10.0, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.beta = beta_init + self.ess_floor = ess_floor + self.beta_min = beta_min + self.beta_max = beta_max + + def _softmax_weights(self, rewards: list[float], beta: float) -> list[float]: + """wi ∝ exp(ri / β), normalised.""" + scaled = [r / beta for r in rewards] + max_s = max(scaled) + exp_vals = [math.exp(s - max_s) for s in scaled] + total = sum(exp_vals) + return [e / total for e in exp_vals] + + def _effective_sample_size(self, weights: list[float]) -> float: + """ESS = (Σwi)² / Σwi² — normalised to [0, 1].""" + n = len(weights) + sum_sq = sum(w ** 2 for w in weights) + return 1.0 / (n * sum_sq) if sum_sq > 0 else 1.0 + + def _adapt_beta(self, rewards: list[float]) -> float: + """Tune β so that ESS ≥ ess_floor.""" + beta = self.beta + for _ in range(20): + weights = self._softmax_weights(rewards, beta) + ess = self._effective_sample_size(weights) + if ess >= self.ess_floor: + break + beta = min(beta * 1.5, self.beta_max) + self.beta = max(self.beta_min, min(beta, self.beta_max)) + return self.beta + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + rewards = [r.reward for r in rollouts] + mean_reward_before = sum(rewards) / max(len(rewards), 1) + + beta = self._adapt_beta(rewards) + weights = self._softmax_weights(rewards, beta) + + total_loss = 0.0 + for rollout, w in zip(rollouts, weights): + # Weighted policy gradient: -w * log π(a|s) + total_loss += -w # placeholder: multiply by actual log-prob + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "beta": beta, + "ess": self._effective_sample_size(weights), + }, + ) + + def select_when(self) -> str: + return ( + "The reward histogram is heavily right-skewed — tasks where correct solutions " + "are rare but individually high-signal, such as hard mathematical proofs or " + "low-pass-rate code synthesis." + ) diff --git a/SIA/weight_updates/grpo.py b/SIA/weight_updates/grpo.py new file mode 100644 index 00000000..6b5a6f84 --- /dev/null +++ b/SIA/weight_updates/grpo.py @@ -0,0 +1,81 @@ +"""Group Relative Policy Optimisation (GRPO). + +Observed when: rollouts are cheap to sample and the verifier fires at episode +end — classification, short-answer, or unit-test tasks where hundreds of +completions can be scored in a single forward pass. + +Advantages are normalised within a rollout group of size G: +Âi = (ri − r̄) / σr, eliminating the value network entirely. +This halves memory and enables large parallel batches. (§7.3) +""" +from __future__ import annotations + +import math +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class GRPO(WeightUpdateAlgorithm): + """GRPO weight update.""" + + name = "grpo" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + group_size: int = 8, + kl_coef: float = 0.01, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.group_size = group_size + self.kl_coef = kl_coef + + def _group_advantages(self, rollouts: list[Rollout]) -> list[float]: + """Normalise rewards within groups of size G: Âi = (ri − r̄) / σr.""" + advantages = [] + for i in range(0, len(rollouts), self.group_size): + group = rollouts[i:i + self.group_size] + rewards = [r.reward for r in group] + mean_r = sum(rewards) / len(rewards) + std_r = math.sqrt(sum((r - mean_r) ** 2 for r in rewards) / len(rewards)) + 1e-8 + advantages.extend((r - mean_r) / std_r for r in rewards) + return advantages + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + mean_reward_before = sum(r.reward for r in rollouts) / max(len(rollouts), 1) + advantages = self._group_advantages(rollouts) + + total_loss = 0.0 + for rollout, adv in zip(rollouts, advantages): + # Policy gradient loss: -log π(a|s) * Â + policy_loss = -adv # placeholder: multiply by actual log-prob + total_loss += policy_loss + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "group_size": self.group_size, + "kl_coef": self.kl_coef, + }, + ) + + def select_when(self) -> str: + return ( + "Rollouts are cheap to sample and the verifier fires at episode end — " + "classification, short-answer, or unit-test tasks where hundreds of " + "completions can be scored in a single forward pass." + ) diff --git a/SIA/weight_updates/ppo_gae.py b/SIA/weight_updates/ppo_gae.py new file mode 100644 index 00000000..7b0322f5 --- /dev/null +++ b/SIA/weight_updates/ppo_gae.py @@ -0,0 +1,108 @@ +"""PPO with Generalised Advantage Estimation (GAE). + +Observed when: step-level rewards are dense and training stability is the +binding constraint — multi-step tool-use or long code-generation tasks where +a single catastrophic update would collapse the policy. + +A learned value head Vϕ produces per-token advantage estimates +Ât = Σ_l (γλ)^l δ_{t+l}; a clipped surrogate +min(r_t Ât, clip(r_t, 1±ε) Ât) prevents the policy from leaving the trust +region. (§7.3) +""" +from __future__ import annotations + +import math +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class PPOWithGAE(WeightUpdateAlgorithm): + """PPO + GAE weight update.""" + + name = "ppo_gae" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + clip_epsilon: float = 0.2, + gamma: float = 0.99, + lam: float = 0.95, + n_epochs: int = 4, + minibatch_size: int = 8, + value_coef: float = 0.5, + entropy_coef: float = 0.01, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.clip_epsilon = clip_epsilon + self.gamma = gamma + self.lam = lam + self.n_epochs = n_epochs + self.minibatch_size = minibatch_size + self.value_coef = value_coef + self.entropy_coef = entropy_coef + + def _compute_gae(self, rollouts: list[Rollout]) -> list[float]: + """Compute GAE advantages Ât = Σ_l (γλ)^l δ_{t+l}.""" + advantages = [] + gae = 0.0 + for rollout in reversed(rollouts): + v = rollout.value if rollout.value is not None else 0.0 + delta = rollout.reward - v + gae = delta + self.gamma * self.lam * gae + advantages.insert(0, gae) + return advantages + + def _clip_surrogate_loss(self, ratio: float, advantage: float) -> float: + """min(r_t Ât, clip(r_t, 1±ε) Ât).""" + clipped = max(1 - self.clip_epsilon, min(1 + self.clip_epsilon, ratio)) + return min(ratio * advantage, clipped * advantage) + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + mean_reward_before = sum(r.reward for r in rollouts) / max(len(rollouts), 1) + advantages = self._compute_gae(rollouts) + + # Normalise advantages + mean_adv = sum(advantages) / len(advantages) + std_adv = math.sqrt(sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)) + 1e-8 + advantages = [(a - mean_adv) / std_adv for a in advantages] + + total_loss = 0.0 + for epoch in range(self.n_epochs): + for i in range(0, len(rollouts), self.minibatch_size): + batch = list(zip(rollouts[i:i + self.minibatch_size], advantages[i:i + self.minibatch_size])) + for rollout, adv in batch: + ratio = 1.0 # placeholder: exp(log_prob_new - log_prob_old) + policy_loss = -self._clip_surrogate_loss(ratio, adv) + value_loss = (rollout.reward - (rollout.value or 0.0)) ** 2 + total_loss += policy_loss + self.value_coef * value_loss + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "clip_epsilon": self.clip_epsilon, + "gamma": self.gamma, + "lam": self.lam, + "n_epochs": self.n_epochs, + }, + ) + + def select_when(self) -> str: + return ( + "Step-level rewards are dense and training stability is the binding constraint; " + "multi-step tool-use or long code-generation tasks where a single catastrophic " + "update would collapse the policy." + ) diff --git a/SIA/weight_updates/reinforce_kl.py b/SIA/weight_updates/reinforce_kl.py new file mode 100644 index 00000000..6a802fba --- /dev/null +++ b/SIA/weight_updates/reinforce_kl.py @@ -0,0 +1,77 @@ +"""REINFORCE + KL-to-base regularisation. + +Observed when: the reward is dense and the primary risk is capability +regression rather than gradient variance — fine-grained domain-adaptation +tasks where the base model is already near-capable and large parameter +movement is undesirable. + +Monte Carlo returns Rt = Σ_{t'≥t} γ^{t'-t} r_{t'} serve as advantages +directly, augmented with a penalty α KL(πθ ‖ πθ₀) against the frozen +reference. No critic, no grouping — the simplest possible training loop. +(§7.3) +""" +from __future__ import annotations + +from typing import Any + +from .base import WeightUpdateAlgorithm, WeightUpdateResult, Rollout + + +class REINFORCEWithKL(WeightUpdateAlgorithm): + """REINFORCE with KL penalty to the frozen reference policy.""" + + name = "reinforce_kl" + + def __init__( + self, + lora_rank: int = 32, + learning_rate: float = 4e-5, + gamma: float = 0.99, + kl_coef: float = 0.1, + **kwargs: Any, + ): + super().__init__(lora_rank=lora_rank, learning_rate=learning_rate, **kwargs) + self.gamma = gamma + self.kl_coef = kl_coef + + def _monte_carlo_returns(self, rollout: Rollout) -> float: + """Rt = reward for single-step rollout (no multi-step decomposition here).""" + return rollout.reward + + def train( + self, + rollouts: list[Rollout], + base_model_id: str, + adapter_path: str | None = None, + output_path: str | None = None, + ) -> WeightUpdateResult: + mean_reward_before = sum(r.reward for r in rollouts) / max(len(rollouts), 1) + + total_loss = 0.0 + for rollout in rollouts: + returns = self._monte_carlo_returns(rollout) + # Policy gradient term: -log π(a|s) * R_t + pg_loss = -returns # placeholder: multiply by actual log-prob + # KL penalty: α * KL(πθ ‖ πθ₀) — approximated as 0 without actual model + kl_penalty = self.kl_coef * 0.0 + total_loss += pg_loss + kl_penalty + + return WeightUpdateResult( + algorithm=self.name, + n_rollouts=len(rollouts), + mean_reward_before=mean_reward_before, + mean_reward_after=None, + loss=total_loss, + adapter_path=output_path, + metadata={ + "gamma": self.gamma, + "kl_coef": self.kl_coef, + }, + ) + + def select_when(self) -> str: + return ( + "The reward is dense and the primary risk is capability regression — " + "fine-grained domain-adaptation tasks where the base model is already " + "near-capable and large parameter movement is undesirable." + )