v1.1.0

@qgallouedec released this 12 Apr 02:15
3179965

Features

DistillationTrainer for efficient on-policy distillation

Read the blog post: https://huggingface.co/spaces/HuggingFaceTB/trl-distillation-trainer

[Figure: off-policy vs. on-policy distillation]

The new DistillationTrainer implements on-policy knowledge distillation as described in "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes". It extends the ideas from the GKDTrainer with three key optimizations: a generation buffer that decouples the training microbatch size from the generation batch size (up to 40x speedup), external teacher server support so the teacher doesn't need to fit on the training GPUs, and binary-encoded logprob payloads that are ~5x smaller to transfer.

from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=DistillationConfig(
        output_dir="results/distill-qwen-gsm8k",
        lmbda=1.0,                   # fully on-policy (student generates)
        beta=1.0,                    # reverse KL
        teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
    ),
    train_dataset=dataset,
)
trainer.train()

by @cmpatino in #5407, #5500 and #5501
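
To make the generation-buffer idea concrete, here is a minimal pure-Python sketch. It is not TRL's implementation: `GenerationBuffer`, `fake_generate`, and the batch sizes are hypothetical names and values chosen for illustration. The point it shows is that one large generation call can feed several smaller training microbatches.

```python
from collections import deque


class GenerationBuffer:
    """Hypothetical buffer: refill with one large generation batch, drain in microbatches."""

    def __init__(self, generate_fn, generation_batch_size):
        self.generate_fn = generate_fn
        self.generation_batch_size = generation_batch_size
        self.buffer = deque()      # holds (prompt, completion) pairs
        self.generation_calls = 0  # counts how often the expensive generator ran

    def refill(self, prompts):
        # Take up to generation_batch_size prompts and generate for all of them in one call.
        n = min(self.generation_batch_size, len(prompts))
        batch = [prompts.popleft() for _ in range(n)]
        completions = self.generate_fn(batch)
        self.generation_calls += 1
        self.buffer.extend(zip(batch, completions))

    def next_microbatch(self, prompts, micro_batch_size):
        # Only call the generator when the buffer cannot serve a full microbatch.
        if len(self.buffer) < micro_batch_size and prompts:
            self.refill(prompts)
        n = min(micro_batch_size, len(self.buffer))
        return [self.buffer.popleft() for _ in range(n)]


def fake_generate(batch):
    # Stand-in for student generation; a real setup would call the model here.
    return [prompt.upper() for prompt in batch]


prompts = deque(f"prompt {i}" for i in range(8))
buf = GenerationBuffer(fake_generate, generation_batch_size=8)

microbatches = []
while True:
    microbatch = buf.next_microbatch(prompts, micro_batch_size=2)
    if not microbatch:
        break
    microbatches.append(microbatch)

print(len(microbatches), "microbatches from", buf.generation_calls, "generation call(s)")
```

In this toy setup the training loop sees microbatches of two samples while generation runs once over all eight prompts, which is the decoupling the release note describes.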

Chunked LM head for memory-efficient log-prob computation in AsyncGRPOTrainer

AsyncGRPOTrainer now supports a chunked LM-head path that computes per-token log-probs and entropy via online logsumexp without materializing the full [N, V] logits tensor. Combined with completion_mask filtering to skip prompt tokens, this reduces peak-allocated memory by up to 44x on an 8192-token sequence:

chunk_lm_head_size   Peak Alloc (GB)   Reduction   Wall Time (ms)
None (baseline)      18.55             1.00x       808.7
4096                 0.42              44.32x      459.0
8192                 0.76              24.34x      393.0

Enable it via the new chunk_lm_head_size option in AsyncGRPOConfig:

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=AsyncGRPOConfig(chunk_lm_head_size=4096),
    ...
)

Note: mutually exclusive with use_liger_kernel (both replace the LM head forward pass).

by @AmineDiro in #5349
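
The online logsumexp trick can be sketched in plain Python. This is an illustration under assumptions, not the AsyncGRPOTrainer code: it assumes the chunking runs over the vocabulary dimension, uses Python lists instead of tensors, and the function names are hypothetical. Only one chunk of logits exists at a time, yet the result matches the full computation.

```python
import math


def chunked_token_logprob(hidden, lm_head, target_id, chunk_size):
    """Log-prob of target_id for one token, visiting the vocabulary chunk by chunk.

    hidden: list of floats; lm_head: list of vocabulary rows (each a list of floats).
    """
    running_max = -math.inf  # largest logit seen so far
    running_sum = 0.0        # sum of exp(logit - running_max) over visited chunks
    target_logit = None
    for start in range(0, len(lm_head), chunk_size):
        rows = lm_head[start:start + chunk_size]
        logits = [sum(h * w for h, w in zip(hidden, row)) for row in rows]
        new_max = max(running_max, max(logits))
        # Rescale the old sum to the new maximum, then add this chunk's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(logit - new_max) for logit in logits
        )
        running_max = new_max
        if start <= target_id < start + len(rows):
            target_logit = logits[target_id - start]
    return target_logit - (running_max + math.log(running_sum))


def full_token_logprob(hidden, lm_head, target_id):
    """Reference version that materializes every logit at once."""
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in lm_head]
    top = max(logits)
    return logits[target_id] - (top + math.log(sum(math.exp(l - top) for l in logits)))


# Toy data: hidden size 3, vocabulary size 10.
hidden = [0.5, -1.0, 2.0]
lm_head = [[math.sin(i + j) for j in range(3)] for i in range(10)]
reference = full_token_logprob(hidden, lm_head, target_id=7)
chunked = chunked_token_logprob(hidden, lm_head, target_id=7, chunk_size=4)
print(abs(chunked - reference) < 1e-9)
```

The memory saving comes from never holding more than `chunk_size` logits per token; the running maximum keeps the exponentials numerically stable across chunks.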

{% generation %} support in training chat templates

SFT with assistant_only_loss=True requires chat templates to include {% generation %} / {% endgeneration %} markers so that return_assistant_tokens_mask=True produces correct masks. Very few models ship these markers natively, so users hit a cryptic error when enabling assistant-only loss with models like Qwen3, Llama 3 or GPT-OSS.

SFTTrainer now automatically swaps in a patched training chat template when the original template lacks generation markers — no manual template surgery required. Training templates are shipped for Qwen2.5, Qwen3, Llama 3 and GPT-OSS, stored as standalone .jinja files under trl/chat_templates/ for readability, diffability, and editor syntax highlighting.

from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen3-4B",
    args=SFTConfig(assistant_only_loss=True),  # now just works
    train_dataset=dataset,
)
trainer.train()

by @qgallouedec in #5459 and #5470, by @RudrenduPaul in #5493 and #5522, and by @casinca in #5484
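
As a rough illustration of why the assistant mask matters, the sketch below shows how such a mask can turn token IDs into training labels. It is a simplified assumption about the mechanism, not SFTTrainer's code: `apply_assistant_mask` is a hypothetical helper and the token IDs are made up. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def apply_assistant_mask(input_ids, assistant_mask):
    """Hypothetical helper: keep labels only where the mask marks assistant tokens."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


# Made-up token IDs: three user-turn tokens followed by three assistant-turn tokens.
input_ids = [11, 12, 13, 21, 22, 23]
assistant_mask = [0, 0, 0, 1, 1, 1]  # 1 = token was rendered inside the generation markers

labels = apply_assistant_mask(input_ids, assistant_mask)
print(labels)
```

If the chat template has no generation markers, the mask is all zeros and every label would be ignored, which is why a template with markers is required.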

Expanded tool-calling model support

Agent training now supports a broader family of models via native tool-call response schemas:

  • GPT-OSS (#5464)
  • GLM-4-MoE (#5463)
  • Qwen3-VL (#5469)
  • Gemma 4 — the first model to natively ship a response schema (#5454)

A new supports_tool_calling() utility detects whether a tokenizer/processor can render a full tool-calling turn, and GRPOTrainer now validates tool support at initialization — raising a clear error upfront instead of failing cryptically mid-training.

by @qgallouedec in #5462, #5464, #5463, #5469 and #5454

Multimodal tool responses for VLM training

environment_factory tool methods can now return multimodal content blocks (images + text) for VLM training. Previously, tool responses were always converted to str(result), discarding any visual information. Now tools can return content block lists with images, and the trainer handles them end-to-end through tokenization, generation, and the forward pass — including correct pixel_values plumbing.

class ScreenshotEnv:
    def take_screenshot(self) -> list[dict]:
        return [
            {"type": "image", "image": self.browser.screenshot()},
            {"type": "text", "text": "Current page state"},
        ]

The OpenEnv browsergym.py example has been migrated to this pattern, and a new carla_vlm.py example demonstrates VLM training against CARLA with camera-image tool responses.

by @sergiopaniego in #5323 and #5437, and by @qgallouedec in #5448
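
The difference between the old `str(result)` behavior and content-block handling can be sketched as follows. This is a hypothetical normalizer written for illustration, not the trainer's actual code path; `to_content_blocks` and the placeholder image bytes are assumptions.

```python
def to_content_blocks(result):
    """Hypothetical normalizer: keep content-block lists intact, wrap anything else as text."""
    is_block_list = isinstance(result, list) and all(
        isinstance(block, dict) and "type" in block for block in result
    )
    if is_block_list:
        return result  # images and text survive unchanged
    return [{"type": "text", "text": str(result)}]  # plain results become one text block


# A tool that returns image + text blocks (the image payload is a placeholder).
multimodal = to_content_blocks(
    [
        {"type": "image", "image": b"<png bytes>"},
        {"type": "text", "text": "Current page state"},
    ]
)

# A tool that returns a plain Python value.
text_only = to_content_blocks(42)

print([block["type"] for block in multimodal], text_only)
```

Calling `str()` on the first result would have flattened the image into a string representation; keeping the block list lets the image reach the processor.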

Built-in reward functions now log extra columns

accuracy_reward and reasoning_accuracy_reward now emit extra diagnostic columns (solution, gold_parsed, answer_parsed) via the log_extra callback introduced in v1.0.0. These show up in the rich completions table, making it much easier to debug why a reward was (or wasn't) assigned.

[Figure: accuracy reward with extra columns]
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(log_completions=True),
    train_dataset=dataset,
)
trainer.train()

by @qgallouedec in #5308
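
A custom reward function could log its own columns in a similar way. The sketch below assumes the callback is passed as a `log_extra` keyword argument and is called with a column name and a list of per-sample values; check the TRL documentation for the exact signature. `exact_match_reward` and its last-word parsing rule are hypothetical, and the callback is replaced here by a local recorder so the example runs without a trainer.

```python
def exact_match_reward(completions, solution, log_extra=None, **kwargs):
    """Hypothetical reward: 1.0 when the last word of a completion equals the solution."""
    parsed = [c.strip().split()[-1] if c.strip() else "" for c in completions]
    if log_extra is not None:
        # Assumed callback shape: log_extra(column_name, per_sample_values)
        log_extra("answer_parsed", parsed)
        log_extra("solution", list(solution))
    return [1.0 if p == s else 0.0 for p, s in zip(parsed, solution)]


# Local stand-in for the trainer-provided callback.
logged = {}


def record(column, values):
    logged[column] = values


rewards = exact_match_reward(
    completions=["The answer is 4", "I think 7"],
    solution=["4", "8"],
    log_extra=record,
)
print(rewards, logged)
```

Logging the parsed answer next to the gold solution makes it easy to see whether a zero reward came from a wrong answer or from a parsing failure.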

Other

Fixes

Deprecations and Removals

  • Deprecate keep_end truncation mode in DPOConfig and SFTConfig — will be removed in v2.0.0. Use keep_start instead. By @albertvillanova in #5465
  • Deprecate pad_token config parameter in DPOConfig, SFTConfig, and RewardConfig — will be removed in v2.0.0. Set tokenizer.pad_token directly on the processing_class instead. By @albertvillanova in #5480
  • Remove trl.experimental.judges module and all judge support from trainers. Judges were experimental, unused in practice, and llm-blender (backing PairRMJudge) was unmaintained and incompatible with transformers v5 — actively blocking v5 adoption. Everything judges did can be achieved with reward_funcs. OnlineDPOTrainer, NashMDTrainer, and XPOTrainer are now unified on reward-model scoring only. By @qgallouedec in #5485

Documentation and Examples

CI

What's Changed

New Contributors

Full Changelog: v1.0.0...v1.1.0