# v1.1.0

## Features

### DistillationTrainer for efficient on-policy distillation

Read the blog post: https://huggingface.co/spaces/HuggingFaceTB/trl-distillation-trainer
The new `DistillationTrainer` implements on-policy knowledge distillation as described in *On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes*. It extends the ideas from `GKDTrainer` with three key optimizations: a generation buffer that decouples the training microbatch size from the generation batch size (up to 40x speedup), external teacher-server support so the teacher doesn't need to fit on the training GPUs, and binary-encoded logprob payloads that shrink transfer sizes by ~5x.
```python
from datasets import load_dataset

from trl.experimental.distillation import DistillationConfig, DistillationTrainer

dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=DistillationConfig(
        output_dir="results/distill-qwen-gsm8k",
        lmbda=1.0,  # fully on-policy (student generates)
        beta=1.0,  # reverse KL
        teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
    ),
    train_dataset=dataset,
)
trainer.train()
```

by @cmpatino in #5407, #5500 and #5501
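The ~5x shrinkage comes from sending logprobs as packed binary floats instead of decimal text. The snippet below is a minimal stdlib-only illustration of that size gap; it is not TRL's actual wire format, and the payload shape here is invented for the example:

```python
import json
import random
import struct

# Toy payload: teacher logprobs for 64 positions x 8 top tokens
# (random values; the shape is invented for this illustration).
random.seed(0)
logprobs = [random.uniform(-20.0, 0.0) for _ in range(64 * 8)]

# Text encoding: each float costs ~20 bytes of decimal digits in JSON.
json_payload = json.dumps(logprobs).encode("utf-8")

# Binary encoding: exactly 4 bytes per little-endian float32.
binary_payload = struct.pack(f"<{len(logprobs)}f", *logprobs)

print(f"JSON: {len(json_payload)} B, "
      f"binary: {len(binary_payload)} B, "
      f"ratio: {len(json_payload) / len(binary_payload):.1f}x")
```

The binary form trades a little precision (float32) for a fixed 4 bytes per value, which is where the payload reduction comes from.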
### Chunked LM head for memory-efficient log-prob computation in AsyncGRPOTrainer

`AsyncGRPOTrainer` now supports a chunked LM-head path that computes per-token log-probs and entropy via online logsumexp without materializing the full `[N, V]` logits tensor. Combined with `completion_mask` filtering to skip prompt tokens, this brings massive memory savings on long sequences — up to 44x lower peak-allocated memory on an 8192-token sequence:
| chunk_lm_head_size | Peak Alloc (GB) | Reduction | Wall Time (ms) |
|---|---|---|---|
| None (baseline) | 18.55 | 1.00x | 808.7 |
| 4096 | 0.42 | 44.32x | 459.0 |
| 8192 | 0.76 | 24.34x | 393.0 |
Enable it via the new `chunk_lm_head_size` option in `AsyncGRPOConfig`:

```python
from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=AsyncGRPOConfig(chunk_lm_head_size=4096),
    ...
)
```

Note: mutually exclusive with `use_liger_kernel` (both replace the LM head forward pass).

by @AmineDiro in #5349
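The online-logsumexp trick can be sketched in plain Python. This is an illustrative scalar version, not TRL's implementation (which operates on GPU tensors and chunks the LM-head projection itself), and `chunked_token_logprob` is a hypothetical helper name:

```python
import math

def chunked_token_logprob(logits_row: list[float], token_id: int, chunk_size: int) -> float:
    """Log-prob of one token via online logsumexp over vocab chunks.

    Walks the vocab dimension chunk by chunk, keeping only a running max
    and a rescaled running sum, so the full softmax row is never stored.
    """
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(logits_row), chunk_size):
        chunk = logits_row[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the accumulated sum to the new reference max.
        running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(x - new_max) for x in chunk)
        running_max = new_max
    return logits_row[token_id] - (running_max + math.log(running_sum))

# Matches the direct computation that materializes the whole row.
logits = [0.5, -1.2, 3.3, 0.1, -0.7, 2.2, 1.1, -2.5]
direct = logits[2] - math.log(sum(math.exp(x) for x in logits))
assert abs(chunked_token_logprob(logits, 2, chunk_size=3) - direct) < 1e-9
```

Because only the running max and sum survive between chunks, peak memory scales with the chunk size rather than the vocabulary size.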
### `{% generation %}` support in training chat templates

SFT with `assistant_only_loss=True` requires chat templates to include `{% generation %}` / `{% endgeneration %}` markers so that `return_assistant_tokens_mask=True` produces correct masks. Very few models ship these markers natively, so users hit a cryptic error when enabling assistant-only loss with models like Qwen3, Llama 3 or GPT-OSS.

`SFTTrainer` now automatically swaps in a patched training chat template when the original template lacks generation markers — no manual template surgery required. Training templates are shipped for Qwen2.5, Qwen3, Llama 3 and GPT-OSS, stored as standalone `.jinja` files under `trl/chat_templates/` for readability, diffability, and editor syntax highlighting.
```python
from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen3-4B",
    args=SFTConfig(assistant_only_loss=True),  # now just works
    train_dataset=dataset,
)
trainer.train()
```

by @qgallouedec in #5459 and #5470, by @RudrenduPaul in #5493 and #5522, and by @casinca in #5484
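For illustration, a template with generation markers looks roughly like the sketch below: only the tokens rendered inside `{% generation %}` … `{% endgeneration %}` contribute to the assistant-token mask. The role tokens (`<|im_start|>`, `<|im_end|>`) are placeholders here, not any particular model's vocabulary; the actual shipped templates are model-specific.

```jinja
{%- for message in messages -%}
<|im_start|>{{ message['role'] }}
{%- if message['role'] == 'assistant' -%}
{% generation %}{{ message['content'] }}<|im_end|>{% endgeneration %}
{%- else -%}
{{ message['content'] }}<|im_end|>
{%- endif -%}
{%- endfor -%}
```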
### Expanded tool-calling model support
Agent training now supports a broader family of models via native tool-call response schemas:
- GPT-OSS (#5464)
- GLM-4-MoE (#5463)
- Qwen3-VL (#5469)
- Gemma 4 — the first model to natively ship a response schema (#5454)
A new `supports_tool_calling()` utility detects whether a tokenizer/processor can render a full tool-calling turn, and `GRPOTrainer` now validates tool support at initialization — raising a clear error upfront instead of failing cryptically mid-training.
by @qgallouedec in #5462, #5464, #5463, #5469 and #5454
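One way such a capability probe can work is to attempt rendering a complete tool-calling exchange and treat any failure as lack of support. The sketch below illustrates the idea with a hypothetical helper; `can_render_tool_turn` is not TRL's actual code, and the message shapes are only an example:

```python
def can_render_tool_turn(tokenizer) -> bool:
    """Hypothetical probe: try rendering a full tool-calling exchange.

    If the chat template cannot express tool definitions, tool calls,
    or tool results, rendering raises and we report False.
    """
    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {"name": "add", "arguments": {"a": 2, "b": 2}},
                }
            ],
        },
        {"role": "tool", "name": "add", "content": "4"},
    ]
    try:
        rendered = tokenizer.apply_chat_template(messages, tools=[add], tokenize=False)
    except Exception:
        return False
    return isinstance(rendered, str) and len(rendered) > 0
```

Doing this check once at trainer init is what turns a mid-training template failure into an upfront, actionable error.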
### Multimodal tool responses for VLM training
`environment_factory` tool methods can now return multimodal content blocks (images + text) for VLM training. Previously, tool responses were always converted to `str(result)`, discarding any visual information. Now tools can return content-block lists with images, and the trainer handles them end-to-end through tokenization, generation, and the forward pass — including correct `pixel_values` plumbing.
```python
class ScreenshotEnv:
    def take_screenshot(self) -> list[dict]:
        return [
            {"type": "image", "image": self.browser.screenshot()},
            {"type": "text", "text": "Current page state"},
        ]
```

The OpenEnv `browsergym.py` example has been migrated to this pattern, and a new `carla_vlm.py` example demonstrates VLM training against CARLA with camera-image tool responses.
by @sergiopaniego in #5323 and #5437, and by @qgallouedec in #5448
### Built-in reward functions now log extra columns

`accuracy_reward` and `reasoning_accuracy_reward` now emit extra diagnostic columns (`solution`, `gold_parsed`, `answer_parsed`) via the `log_extra` callback introduced in v1.0.0. These show up in the rich completions table, making it much easier to debug why a reward was (or wasn't) assigned.
```python
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(log_completions=True),
    train_dataset=dataset,
)
trainer.train()
```

by @qgallouedec in #5308
## Other

- Align KTO with DPO: precompute reference log probs at init by @albertvillanova in #5447
- Align KTO with DPO: reorganize `KTOConfig` by @albertvillanova in #5477
- Use generic VLM key passthrough in DPO by @albertvillanova in #5468
- Make images optional in `prepare_multimodal_messages` by @albertvillanova in #5424
- Avoid image deepcopy in `prepare_multimodal_messages` by @albertvillanova in #5475
- Replace `pixel_position_ids` with `image_position_ids` for Gemma 4 support by @qgallouedec in #5452
- Update vLLM minimum supported version to 0.11.0 by @albertvillanova in #5443
- Remove dead token attributes from trainers by @albertvillanova in #5483
- Remove unnecessary `isinstance(part, dict)` checks in image extraction by @qgallouedec in #5439
- Simplify `_get_tool_suffix_ids` by @qgallouedec in #5440
- Narrow prefix-preserving check to the actual requirement by @qgallouedec in #5458
- Remove duplicated `prepare_deepspeed` by @albertvillanova in #5414
## Fixes

- Fix targeting fused parameters with LoRA by @BenjaminBossan in #5430
- Fix `ImportError` with vllm-0.10.2 in OnlineDPO and OpenEnv by @albertvillanova in #5423
- Fix `_get_per_token_logps_and_entropies` return type by @kashif in #5456
- Fix SFT deprecation warning by @albertvillanova in #5466
- Fix broken validation of user-specified tokens by @albertvillanova in #5482
- Fix `prepare_multimodal_messages` not normalizing empty string content for assistant/tool roles by @albertvillanova in #5496
- Remove redundant alignment of `pad_token_id` by @albertvillanova in #5487
- Replace deprecated `huggingface-cli` references with `hf` by @hanouticelina in #5486
- Remove unused `truncation_mode` from experimental `truncate_dataset` by @albertvillanova in #5467
- Fix PR template check bot reopen loop by @qgallouedec in #5488
- Restrict VLM padding workaround to transformers 5.3.0 by @albertvillanova in #5503
## Deprecations and Removals

- Deprecate `keep_end` truncation mode in `DPOConfig` and `SFTConfig` — will be removed in v2.0.0. Use `keep_start` instead. By @albertvillanova in #5465
- Deprecate `pad_token` config parameter in `DPOConfig`, `SFTConfig`, and `RewardConfig` — will be removed in v2.0.0. Set `tokenizer.pad_token` directly on the `processing_class` instead. By @albertvillanova in #5480
- Remove the `trl.experimental.judges` module and all judge support from trainers. Judges were experimental, unused in practice, and `llm-blender` (backing `PairRMJudge`) was unmaintained and incompatible with transformers v5 — actively blocking v5 adoption. Everything judges did can be achieved with `reward_funcs`. `OnlineDPOTrainer`, `NashMDTrainer`, and `XPOTrainer` are now unified on reward-model scoring only. By @qgallouedec in #5485
## Documentation and Examples

- Update "What's New": TRL v1 blog post by @qgallouedec in #5385
- New `carla_vlm` OpenEnv example by @sergiopaniego in #5437
- Add code example for `completion_only_loss` in SFT trainer docs by @RudrenduPaul in #5494
- Add docs and good defaults for `DistillationTrainer` by @cmpatino in #5500
- Add test and docs for multimodal tool responses by @qgallouedec in #5448
- Add tests for Gemma pixel splitting by @qgallouedec in #5450
- Update docstring about tool messages in `prepare_multimodal_messages` by @albertvillanova in #5476
- Run `make precommit` to fix docstring style by @albertvillanova in #5436
## CI

- Pin GitHub Actions to commit SHAs by @paulinebm in #5435
- Update GitHub Action to use specific version of github-script by @qgallouedec in #5491
- Generic device support for CI tests by @kaixuanliu in #5357
- CI: Gemma 4 support by @qgallouedec in #5453
- Fix CI slow-tests cannot remove: No such file or directory by @albertvillanova in #5401
- Remove xfail for Qwen3VL CI tests by @albertvillanova in #5402
- Fix flaky CI `test_rloo[fsdp2]`: replace non-deterministic xfail with skipif for transformers 5.4.0 by @albertvillanova in #5403
- Mark as strict the xfail tests with zero3 for RLOO and GRPO by @albertvillanova in #5404
- Hotfix CI: mark tests as xfail due to missing `input_ids` or `inputs_embeds` by @albertvillanova in #5422
- Update tests to not pass `eval_strategy` by @SunMarc in #5426
- Hotfix CI: mark tests as xfail with transformers dev due to `TypeError: 'NoneType' object is not iterable` by @albertvillanova in #5427
- Revert hotfix CI for `TypeError: 'NoneType' object is not iterable` by @albertvillanova in #5438
- Update tests with zero3 for RLOO and GRPO as xfail only with transformers >= v5 by @albertvillanova in #5420
- Hotfix CI: update skipif for `test_rloo[fsdp2]` after transformers 5.5.0 release by @albertvillanova in #5442
- Remove xfail for ZeRO 2 and 3 + SFT + PEFT test by @qgallouedec in #5383
- Better test consistency RLOO vs GRPO by @qgallouedec in #5396
- Hotfix CI: mark tests as xfail with transformers dev for Llava models by @albertvillanova in #5504
## What's Changed

- ⬆️ Bump dev version by @qgallouedec in #5410
- Update "What's New": TRL v1 blog post by @qgallouedec in #5385
- Fix CI slow-tests cannot remove: No such file or directory by @albertvillanova in #5401
- Remove xfail for Qwen3VL CI tests by @albertvillanova in #5402
- Fix flaky CI `test_rloo[fsdp2]`: Replace non-deterministic xfail with skipif for transformers 5.4.0 by @albertvillanova in #5403
- Mark as strict the xfail tests with zero3 for RLOO and GRPO by @albertvillanova in #5404
- Remove duplicated `prepare_deepspeed` by @albertvillanova in #5414
- Hotfix CI: Mark tests as xfail due to missing `input_ids` or `inputs_embeds` by @albertvillanova in #5422
- Update tests to not pass `eval_strategy` by @SunMarc in #5426
- Hotfix CI: Mark tests as xfail with transformers dev due to `TypeError: 'NoneType' object is not iterable` by @albertvillanova in #5427
- FIX CI: Targeting fused parameters with LoRA by @BenjaminBossan in #5430
- Support multimodal tool responses in `environment_factory` for VLM training by @sergiopaniego in #5323
- 🔒 Pin GitHub Actions to commit SHAs by @paulinebm in #5435
- New carla vlm example by @sergiopaniego in #5437
- Revert hotfix CI for `TypeError: 'NoneType' object is not iterable` by @albertvillanova in #5438
- Run `make precommit` to fix docstring style by @albertvillanova in #5436
- Fix `ImportError` with vllm-0.10.2 in OnlineDPO and OpenEnv by @albertvillanova in #5423
- Add chunked LM head for memory-efficient log-prob computation for AsyncGRPOTrainer by @AmineDiro in #5349
- Update tests with zero3 for RLOO and GRPO as xfail only with transformers >= v5 by @albertvillanova in #5420
- Make images optional in `prepare_multimodal_messages` by @albertvillanova in #5424
- Hotfix CI: Update skipif for `test_rloo[fsdp2]` after transformers 5.5.0 release by @albertvillanova in #5442
- Update vLLM minimum supported version to 0.11.0 by @albertvillanova in #5443
- Better test consistency RLOO vs GRPO by @qgallouedec in #5396
- Align KTO with DPO: Precompute reference log probs at init by @albertvillanova in #5447
- Add support for logging extra columns in reward functions and update related tests by @qgallouedec in #5308
- Remove unnecessary `isinstance(part, dict)` checks in image extraction by @qgallouedec in #5439
- Remove xfail for ZeRO 2 and 3 + SFT + PEFT test by @qgallouedec in #5383
- Replace `pixel_position_ids` with `image_position_ids` for Gemma4 support by @qgallouedec in #5452
- Add test and docs for multimodal tool responses by @qgallouedec in #5448
- Add tests for Gemma pixel splitting by @qgallouedec in #5450
- Generic device support for CI tests by @kaixuanliu in #5357
- Revert speculative argument parsing and add Gemma4 agent support by @qgallouedec in #5454
- fix `_get_per_token_logps_and_entropies` return type by @kashif in #5456
- Deprecate `keep_end` truncation mode by @albertvillanova in #5465
- Fix SFT deprecation warning by @albertvillanova in #5466
- Remove unused `truncation_mode` from experimental `truncate_dataset` by @albertvillanova in #5467
- Use generic VLM key passthrough in DPO by @albertvillanova in #5468
- Narrow prefix-preserving check to the actual requirement by @qgallouedec in #5458
- Simplify `_get_tool_suffix_ids` by @qgallouedec in #5440
- Update docstring about tool messages in `prepare_multimodal_messages` by @albertvillanova in #5476
- CI Gemma 4 support by @qgallouedec in #5453
- Move chat templates from inline strings to `.jinja` files by @qgallouedec in #5459
- Align KTO with DPO: Reorganize `KTOConfig` by @albertvillanova in #5477
- Add `supports_tool_calling` utility and validate tool support at init by @qgallouedec in #5462
- Add GPT-OSS tool calling support by @qgallouedec in #5464
- Add `{% generation %}` support to training chat templates by @qgallouedec in #5470
- Avoid image deepcopy in `prepare_multimodal_messages` by @albertvillanova in #5475
- Remove dead token attributes from trainers by @albertvillanova in #5483
- Add `DistillationTrainer` for efficient on-policy distillation by @cmpatino in #5407
- Replace deprecated `huggingface-cli` references with `hf` by @hanouticelina in #5486
- Fix broken validation of user-specified tokens by @albertvillanova in #5482
- Deprecate `pad_token` config parameter by @albertvillanova in #5480
- Remove redundant alignment of `pad_token_id` by @albertvillanova in #5487
- Fix PR template check bot reopen loop by @qgallouedec in #5488
- feat(gpt-oss): Add `{% generation %}` markers for training chat template by @casinca in #5484
- Remove the `trl.experimental.judges` module and all judge support from trainers by @qgallouedec in #5485
- Hotfix CI: Mark tests as xfail with transformers dev for Llava models by @albertvillanova in #5504
- Restrict VLM padding workaround to transformers 5.3.0 by @albertvillanova in #5503
- Update GitHub Action to use specific version of github-script by @qgallouedec in #5491
- [docs] Add code example for `completion_only_loss` in SFT trainer docs by @RudrenduPaul in #5494
- Fix `prepare_multimodal_messages` not normalizing empty string content for assistant/tool roles by @albertvillanova in #5496
- Add trackio support to `DistillationTrainer` by @cmpatino in #5501
- feat: add Llama 3 training chat template with generation markers by @RudrenduPaul in #5493
- Add GLM-4-MoE tool calling support by @qgallouedec in #5463
- Add Qwen3-VL tool calling support by @qgallouedec in #5469
- Add docs and good defaults for `DistillationTrainer` by @cmpatino in #5500
- feat: add Qwen2.5 training chat template with generation markers by @RudrenduPaul in #5522
- Release: v1.1 by @qgallouedec in #5524
## New Contributors

- @BenjaminBossan made their first contribution in #5430
- @hanouticelina made their first contribution in #5486
- @RudrenduPaul made their first contribution in #5494

Full Changelog: v1.0.0...v1.1.0