generation.py
import random
from copy import deepcopy
from typing import Dict, NamedTuple, Sequence, Optional
import numpy as np
import torch
import transformers
import shared_cache
from formatting import FormattingBase, MathFormatting
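

# Snapshot of a partially finished collaborative reasoning trace: `history` holds the tokens of
# all completed steps (shared between both workers), while `current_step_tokens_by_worker`
# holds each worker's still-unfinished current step.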
ReasoningState = NamedTuple("ReasoningState", (
    ("history", Sequence[int]),
    ("current_step_tokens_by_worker", Sequence[Sequence[int]]),
))


def solve_math_2agents(*args, **kwargs):
raise NotImplementedError("""Use solve_task_2agents""")


def solve_task_2agents(
*,
problem: str,
model: transformers.PreTrainedModel,
tokenizer: transformers.PreTrainedTokenizer,
finisher_max_new_tokens: int = 16,
fmt: Optional[FormattingBase] = None,
**reasoning_kwargs,
) -> Dict[int, str]:
"""Generate reasoning traces with 2 parallel agents, return responses (with s1-like finisher)"""
fmt = fmt if fmt is not None else MathFormatting(tokenizer)
saved_reasoning_states = generate_reasoning_2agents(
problem=problem, model=model, tokenizer=tokenizer, fmt=fmt, **reasoning_kwargs)
outputs = dict()
for budget, reasoning_state in saved_reasoning_states.items():
generated_tokens = list(reasoning_state.history)
generated_tokens.extend(tokenizer.encode(fmt.current_step_header, add_special_tokens=False))
for worker_index, worker_tokens in enumerate(reasoning_state.current_step_tokens_by_worker):
generated_tokens.extend(worker_tokens)
generated_tokens.extend(tokenizer.encode(fmt.incomplete_step + fmt.sep, add_special_tokens=False))
problem_tokens = list(tokenizer.encode(fmt.apply_chat_template(problem), add_special_tokens=False))
response = tokenizer.decode(problem_tokens + generated_tokens)
if finisher_max_new_tokens > 0 and (fmt.get_final_answer(tokenizer.decode(generated_tokens)) is None):
response = finalize_response_with_s1_finisher(
response=response, model=model, tokenizer=tokenizer, fmt=fmt, max_new_tokens=finisher_max_new_tokens
)
outputs[budget] = response
return outputs


def generate_reasoning_2agents(
*,
problem: str,
model: transformers.PreTrainedModel,
tokenizer: transformers.PreTrainedTokenizer,
fmt: FormattingBase,
max_steps: int,
save_on_steps: Sequence[int] = (),
insert_s1_collab_message_every_tokens: int = 1024,
suppress_tokens: Sequence[int] = (),
) -> Dict[int, ReasoningState]:
"""Generate reasoning traces and return snapshot for a given max_steps and any extra snapshots for save_on_steps"""
assert all(save_step <= max_steps for save_step in save_on_steps), save_on_steps
saved_states = dict()
logits_processor = get_logits_processor(model, suppress_tokens)
device = next(model.parameters()).device
tokenizer_kwargs = dict(return_tensors='pt', padding=True, padding_side='left', add_special_tokens=False)
tokens_since_last_wait = 0
cache_common, cache_current_step_header, cache_own_header, cache_w1, cache_w2 = (
shared_cache.CacheBlock(config=model.config) for _ in range(5))
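    # Cache layout per worker: shared problem prompt, then the current-step header, then the
    # *other* worker's unfinished step, then its own header, then its own unfinished step.
    # The blocks are shared objects, so entries appended to cache_common are visible to both workers.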
cm = shared_cache.SharedCacheManager(cache_structure=[
[cache_common, cache_current_step_header, cache_w2, cache_own_header, cache_w1],
[cache_common, cache_current_step_header, cache_w1, cache_own_header, cache_w2],
])
# pre-fill common cache parts
with torch.inference_mode():
model(**tokenizer(fmt.apply_chat_template(problem), **tokenizer_kwargs).to(device),
use_cache=True, past_key_values=cache_common) # <-- write to common prompt
model(**tokenizer(fmt.current_step_header, **tokenizer_kwargs).to(device),
use_cache=True, past_key_values=cache_current_step_header) # <-- write to the separator after history
model(**tokenizer(fmt.current_worker_header, **tokenizer_kwargs).to(device),
use_cache=True, past_key_values=cache_own_header) # <-- write to separator between incomplete steps
# generate interdependent reasoning chains in parallel
current_step_index_by_worker = [1, 1]
current_step_tokens_by_worker = tokenizer(list(fmt.worker_prompts), add_special_tokens=False)['input_ids']
history = list()
next_inputs = tokenizer(list(fmt.worker_prompts), **tokenizer_kwargs).to(device)
for inference_step in range(max_steps):
# run model with a shared cache (batched inference)
with torch.inference_mode():
logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
logits = logits_processor(next_inputs['input_ids'], logits)
new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten(
) if model.generation_config.do_sample else logits.argmax(-1)
assert len(new_tokens) == len(fmt.workers)
        # process the newly generated tokens: detect step completion and build next_inputs
next_input_tokens = new_tokens.unsqueeze(-1).tolist()
for worker_index, (worker_name, worker_tokens, new_token) in enumerate(
zip(fmt.workers, current_step_tokens_by_worker, new_tokens.tolist())):
worker_tokens.append(new_token)
if fmt.is_end_of_step(worker_tokens):
# worker just finished their step - add it to common history and start a new step
current_step_index_by_worker[worker_index] += 1
history.extend(worker_tokens)
worker_tokens.clear()
start_msg = fmt.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])
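                    # roughly every insert_s1_collab_message_every_tokens generated tokens, prepend the
                    # collaboration reminder (fmt.s1_collab_message) to the next step's prefix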
if tokens_since_last_wait > insert_s1_collab_message_every_tokens:
start_msg += fmt.s1_collab_message
tokens_since_last_wait = 0
worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))
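                # move the finished step's KV entries from this worker's private block into the
                # shared history block (cache_common), then reset the private block for the new step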
cache_common.append_from(cm.cache_structure[worker_index][-1])
cm.cache_structure[worker_index][-1].clear()
next_input_tokens[worker_index] = [new_token] + worker_tokens
tokens_since_last_wait += len(next_input_tokens[worker_index])
next_inputs = tokenizer.pad(
dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)
if torch.any(new_tokens == tokenizer.eos_token_id).item():
break # at least one worker generated the end-of-sequence token, finish early
if inference_step in save_on_steps:
saved_states[inference_step] = ReasoningState(
history=list(history), current_step_tokens_by_worker=deepcopy(current_step_tokens_by_worker),
)
# if we finished early, copy the reasoning state for all subsequent budgets
for step_to_save in set(save_on_steps) | {max_steps}:
if step_to_save not in saved_states:
saved_states[step_to_save] = ReasoningState(
history=list(history), current_step_tokens_by_worker=deepcopy(current_step_tokens_by_worker),
)
return saved_states


def finalize_response_with_s1_finisher(
*, response: str, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer,
fmt: FormattingBase, max_new_tokens: int, chunk_size: int = 4096) -> str:
"""Compile a response from a reasoning state. If there is no answer yet, prompt the model to give the answer"""
with torch.inference_mode():
device = next(model.parameters()).device
response_ids = tokenizer.encode(response + fmt.s1_finisher_suffix, add_special_tokens=False)
assert isinstance(response_ids, Sequence) and isinstance(response_ids[0], int)
prefix_num_tokens = len(tokenizer.encode(response, add_special_tokens=False))
assert prefix_num_tokens < len(response_ids)
cache = transformers.DynamicCache()
# encode prompt in chunks to save memory
next_logits = None
for chunk_start in range(0, len(response_ids), chunk_size):
next_logits = model(
input_ids=torch.tensor([response_ids[chunk_start: chunk_start + chunk_size]],
device=device, dtype=torch.int64),
attention_mask=torch.ones(1, min(chunk_start + chunk_size, len(response_ids)),
device=device, dtype=torch.int64),
use_cache=True, past_key_values=cache
).logits[..., -1, :] # [batch_size(1), vocab_size]
assert cache.get_seq_length() == min(chunk_start + chunk_size, len(response_ids))
        # run up to max_new_tokens steps of *always greedy* output generation
next_tokens = next_logits.argmax(-1, keepdims=True) # [batch_size(1), 1]
response_ids.append(next_tokens.item())
for inference_step in range(max_new_tokens - 1):
full_mask = torch.ones(next_tokens.shape[0], len(response_ids), device=device, dtype=torch.int64)
next_logits = model(
input_ids=next_tokens, attention_mask=full_mask, use_cache=True, past_key_values=cache
).logits[..., -1, :]
next_tokens = next_logits.argmax(-1, keepdims=True) # [batch_size(1), 1]
response_ids.append(next_tokens.item())
if response_ids[-1] == tokenizer.eos_token_id:
break
if fmt.get_final_answer(tokenizer.decode(response_ids[prefix_num_tokens:])) is not None:
break # generated a valid response - finish early
response: str = tokenizer.decode(response_ids)
return response


def get_logits_processor(model: transformers.PreTrainedModel, suppress_tokens: Sequence[int] = ()):
"""Create a transformers class that post-processes model logits for nucleus sampling, banned tokens, etc"""
generation_config, model_kwargs = model._prepare_generation_config(model.generation_config)
model._prepare_special_tokens(generation_config)
device = next(model.parameters()).device
return model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=0,
encoder_input_ids=None,
prefix_allowed_tokens_fn=None,
logits_processor=transformers.LogitsProcessorList([
transformers.generation.logits_process.SuppressTokensLogitsProcessor(
suppress_tokens, device=device)]),
device=device,
model_kwargs=model_kwargs
)


def fix_seed(seed: int):
    """Seed python, numpy, torch (CPU and CUDA) and transformers RNGs for reproducible generation"""
    random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
transformers.set_seed(seed)
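

# Usage sketch: one possible way to run the two-agent solver end to end. The checkpoint name,
# problem text and budgets below are placeholders, not values prescribed by this module.
if __name__ == "__main__":
    fix_seed(0)
    model_name = "Qwen/QwQ-32B"  # placeholder checkpoint; any chat LLM the formatting supports
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto")
    responses = solve_task_2agents(
        problem="What is the sum of all integers from 1 to 100?",
        model=model,
        tokenizer=tokenizer,
        max_steps=2048,              # total parallel decoding steps (forwarded to generate_reasoning_2agents)
        save_on_steps=(512, 1024),   # also snapshot these smaller budgets
        finisher_max_new_tokens=16,  # force a short final answer if none was produced
    )
    for budget, response in sorted(responses.items()):
        print(f"=== budget: {budget} steps ===\n{response}\n")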