-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathdpo_multi_lora.py
More file actions
216 lines (172 loc) · 10 KB
/
dpo_multi_lora.py
File metadata and controls
216 lines (172 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""DPO (Direct Preference Optimization) Training with MultiLoRA (Megatron Backend).
MultiLoRA-based DPO training: uses the base model (without LoRA adapter) as reference
model by calling forward_only with disable_lora=True. This eliminates the need for
a separate reference model GPU group.
Uses Megatron backend with MultiLoRAMegatronModel for efficient multi-tenant LoRA training.
Pipeline:
1. Load preference dataset with chosen/rejected pairs.
2. Encode positive and negative separately.
3. Compute reference model log probabilities using base model (disable_lora=True).
4. Train policy model (with LoRA adapter) using DPO loss.
Architecture (Ray - Single Group):
┌─────────────────────────────────────────────────────────────────┐
│ Driver (CPU) │
│ dataloader ──► batched preference pairs │
│ policy_model.forward_only(disable_lora=True) ──► ref logps │
│ policy_model.forward_backward() ──► DPO loss + gradient │
└─────────────────────────────────────────────────────────────────┘
│
PolicyModel (with LoRA adapter)
- forward_only(disable_lora=True) → base model inference (reference)
- forward_backward() → LoRA adapter training (policy)
DPO data format (after preprocessing):
- positive: List[Trajectory] - chosen responses
- negative: List[Trajectory] - rejected responses
Environment variables (all optional):
MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
MODEL_GPUS – GPUs for policy model (default: 2)
BATCH_SIZE – global batch size (preference pairs) (default: 8)
MAX_STEPS – total optimization steps (derived from the dataloader length; this env var is currently not read by the script)
LR – learning rate (default: 1e-4)
DPO_BETA – DPO temperature parameter (default: 0.1)
LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid)
SAVE_STEPS – checkpoint save interval (default: 100)
MAX_LENGTH – max sequence length (default: 2048)
"""
import os
from typing import Any, Dict, List, Optional
from peft import LoraConfig
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
from twinkle.data_format import Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import DPOLoss
from twinkle.metric import DPOMetric
from twinkle.preprocessor import EmojiDPOProcessor
from twinkle.processor import InputProcessor
# Module-level logger shared by every function in this script.
logger = get_logger()
# ── Configuration ─────────────────────────────────────────────────────────────
# All knobs are read once at import time from environment variables; the
# fallback value after each env-var name is the effective default.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8))  # Number of preference pairs per batch
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
LEARNING_RATE = float(os.environ.get('LR', 1e-4))  # LoRA DPO requires higher LR (1e-4 to 3e-4)
DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))  # DPO temperature parameter
SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0))  # SFT loss weight for regularization
LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid')  # DPO variant: sigmoid, hinge, ipo
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))  # Checkpoint interval, in loop iterations
MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))  # Max encoded sequence length
# Name under which the single LoRA adapter is registered on the policy model.
ADAPTER_NAME = 'default_0'
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
def create_dpo_dataset():
    """Build and encode the preference dataset used for DPO training.

    Loads a 50-row slice of DATASET_ID, attaches the chat template, runs the
    EmojiDPOProcessor, and encodes everything up front so the dataloader can
    serve ready-to-use batches.

    Returns:
        The fully prepared Dataset instance.
    """
    preference_dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(50)))
    preference_dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    # The processor injects the configured system prompt into each sample.
    preference_dataset.map(EmojiDPOProcessor, init_args={'system': SYSTEM_PROMPT})
    # The DPO preprocessor emits {'positive': [...], 'negative': [...]} rows;
    # batch_encode understands that layout automatically, so one encode pass
    # covers both sides of every preference pair.
    preference_dataset.encode(load_from_cache_file=True)
    return preference_dataset
def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten preference rows into an interleaved [pos_1, neg_1, pos_2, ...] list.

    Args:
        batch: Rows that each carry a 'positive' and a 'negative' InputFeature
            plus arbitrary shared fields (question, etc.).

    Returns:
        A list twice as long as ``batch`` where each consecutive pair is the
        chosen then the rejected sample for one preference row. Interleaving
        guarantees every DP worker receives complete pairs after slicing.
        Each emitted item merges the row's shared fields with one
        InputFeature's fields (InputFeature keys win on collision).
    """
    interleaved: List[Dict[str, Any]] = []
    for row in batch:
        # Everything except the pair itself is common to both samples.
        shared = {key: value for key, value in row.items() if key not in ('positive', 'negative')}
        interleaved.extend(({**shared, **row['positive']}, {**shared, **row['negative']}))
    return interleaved
# ── Main Training Loop ────────────────────────────────────────────────────────
def main() -> None:
    """Run MultiLoRA DPO training end to end: device setup, model/adapter
    configuration, the train loop, and the final checkpoint save.

    The same policy model serves as its own reference: forward_only with
    disable_lora=True runs the base weights, so no separate reference GPU
    group is needed.
    """
    # Set up device groups - only one group for LoRA training
    device_groups = [
        DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
    ]
    # Configure device mesh for MultiLoRA Megatron: dp=2, pp=1
    # NOTE(review): dp_size is hard-coded to 2, so MODEL_GPUS must be a
    # multiple of 2 for the mesh to be valid — confirm against DeviceMesh.
    from twinkle.model import MultiLoraMegatronModel
    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=1)
    ModelClass = MultiLoraMegatronModel
    twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
    # ── DataLoader Setup ──────────────────────────────────────────────────────
    # The dataset factory (not an instance) is passed in, so construction is
    # deferred to the dataloader/workers.
    dataloader = DataLoader(
        dataset=create_dpo_dataset,
        batch_size=BATCH_SIZE,
        min_batch_size=BATCH_SIZE,
        device_mesh=policy_mesh,
    )
    # ── Policy Model Setup with LoRA ──────────────────────────────────────────
    lora_config = LoraConfig(
        target_modules='all-linear',
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
    )
    policy_model = ModelClass(
        model_id=MODEL_ID,
        device_mesh=policy_mesh,
        remote_group='policy',
    )
    # NOTE(review): total steps come from the dataloader length; the
    # MAX_STEPS env var mentioned in the module docstring is not read here.
    MAX_STEPS = len(dataloader)
    policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
    # Configure optimizer based on backend
    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)
    # Set up loss function and metrics
    loss_fn = DPOLoss(
        beta=DPO_BETA,
        loss_type=LOSS_TYPE,
        reference_free=False,  # We use base model as reference via disable_lora=True
        sft_weight=SFT_WEIGHT,
    )
    policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME)
    policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)
    policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
    policy_model.set_template('Qwen3_5Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
    optim_step = 0
    backend_name = 'MultiLoRA Megatron'
    logger.info(get_device_placement())
    logger.info(f'Starting MultiLoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
    logger.info(f'Using base model (disable_lora=True) as reference model')
    # ── Training Loop ─────────────────────────────────────────────────────────
    for batch in dataloader:
        # batch is List[Dict] with 'positive' and 'negative' keys
        dpo_batch = prepare_dpo_batch(batch)
        # Get reference outputs using base model (without LoRA adapter)
        # disable_lora=True tells the model to skip LoRA and use base weights
        ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True, adapter_name=ADAPTER_NAME)
        policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs, adapter_name=ADAPTER_NAME)
        # NOTE(review): presumably clip_grad_and_step only applies the
        # optimizer every GRADIENT_ACCUMULATION_STEPS calls (the adapter was
        # registered with that setting) — verify in MultiLoraMegatronModel.
        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
        # optim_step counts loop iterations (micro-batches), not optimizer
        # updates; the log below divides it back down for display.
        optim_step += 1
        # Logging
        if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
            # NOTE(review): MAX_STEPS is len(dataloader) (iterations), while
            # the displayed step is iterations // accumulation — the shown
            # "x/MAX_STEPS" mixes the two units; confirm intended display.
            logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
        # Checkpointing
        if optim_step % SAVE_STEPS == 0:
            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)
    # ── Save Final Checkpoint ─────────────────────────────────────────────────
    logger.info(f'Training completed. Total steps: {optim_step}')
    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)
if __name__ == '__main__':
    main()