-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
306 lines (260 loc) · 11.6 KB
/
generate.py
File metadata and controls
306 lines (260 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import torch
from torch.nn import functional as F
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.cache_utils import DynamicCache
def top_k_logits(logits, k):
    """Keep only the ``k`` largest logits along the last dimension.

    Every logit strictly smaller than the k-th largest value is replaced by
    ``-inf``; values tied with the k-th largest are kept.

    Args:
        logits: Tensor of logits; filtering is applied over the last dim.
        k: Number of top entries to keep. ``k <= 0`` disables filtering, and
            a ``k`` larger than the last dimension keeps every entry.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    if k <= 0:
        return logits
    # torch.topk raises if k exceeds the dimension size; clamp so that an
    # oversized k simply means "keep everything".
    k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, k)
    min_values = values[..., -1, None]
    return torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)
def top_p_logits(logits, p):
    """Apply nucleus (top-p) filtering over the last dimension.

    Tokens are ranked by logit. The shortest prefix of that ranking whose
    cumulative softmax probability exceeds ``p`` is kept; the token that
    crosses the threshold is included and the top-ranked token is always
    kept. Every other logit is set to ``-inf``.
    """
    ordered, order = torch.sort(logits, descending=True)
    cum_probs = F.softmax(ordered, dim=-1).cumsum(dim=-1)
    drop_sorted = cum_probs > p
    # Shift right by one position so the token that pushes the mass past p
    # survives, and the first (most likely) token is never dropped.
    drop_sorted = torch.cat(
        [torch.zeros_like(drop_sorted[..., :1]), drop_sorted[..., :-1]], dim=-1
    )
    # Map the mask from sorted order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, order, drop_sorted)
    return logits.masked_fill(drop, float('-inf'))
def detect_tail_copy_repetition(tokens, min_repeats=10):
    """Return True if ``tokens`` ends with one unit repeated ``min_repeats``+ times.

    For each candidate unit length ``L``, the last ``L`` tokens are taken as
    the unit and the number of back-to-back copies of it at the tail of the
    sequence is counted (the final copy included).

    Args:
        tokens: Sliceable sequence supporting ``==`` (e.g. a list of token ids).
        min_repeats: Minimum number of consecutive copies needed to report
            a repetition.

    Returns:
        True if some tail unit occurs at least ``min_repeats`` times in a row,
        otherwise False. Sequences shorter than 2 always return False.
    """
    n = len(tokens)
    if n < 2:
        return False
    # A unit of length L fits at most n // L times, so lengths above
    # n // min_repeats can never reach the threshold. The cap never exceeds
    # n // 2 (a unit needs room for two copies), which is the full range the
    # search would otherwise cover, so results are unchanged.
    max_unit = n // max(min_repeats, 2)
    for unit_len in range(1, max_unit + 1):
        unit = tokens[-unit_len:]
        repeats = 1
        # Walk backwards one unit at a time while the previous chunk matches.
        idx = n - unit_len
        while idx - unit_len >= 0 and tokens[idx - unit_len:idx] == unit:
            repeats += 1
            if repeats >= min_repeats:
                return True
            idx -= unit_len
        # Covers min_repeats <= 1, where a single copy already qualifies.
        if repeats >= min_repeats:
            return True
    return False
def sample_with_temperature_topk_topp(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Sample one token per position and return it with its confidence.

    Args:
        logits: Tensor whose last dimension is the vocabulary (e.g.
            ``[batch, block, vocab]``). Leading dimensions are flattened for
            sampling and restored on output.
        temperature: Logits are divided by this before both the confidence
            softmax and the sampling softmax (skipped when exactly 1.0).
        top_k: If > 0, restrict sampling to the ``top_k`` largest logits.
        top_p: If < 1.0, restrict sampling to the nucleus of mass ``top_p``.

    Returns:
        ``(tokens, token_probs)``, both shaped ``logits.shape[:-1]``.
        ``token_probs`` is the probability of each sampled token under the
        temperature-scaled distribution *before* top-k/top-p truncation; it is
        used as the confidence score for remasking decisions.
    """
    orig_shape = logits.shape[:-1]
    vocab_size = logits.shape[-1]
    logits = logits.reshape(-1, vocab_size)  # [N, vocab]

    if temperature != 1.0:
        logits = logits / temperature
    # Confidence distribution: temperature applied, no truncation.
    ori_probs = F.softmax(logits, dim=-1)

    # Truncated distribution, used only to draw the sample.
    if top_k > 0:
        logits = top_k_logits(logits, top_k)
    if top_p < 1.0:
        logits = top_p_logits(logits, top_p)
    probs = F.softmax(logits, dim=-1)

    token = torch.multinomial(probs, num_samples=1)  # [N, 1]
    token_prob = torch.gather(ori_probs, -1, token)  # [N, 1]
    return token.view(*orig_shape), token_prob.view(*orig_shape)
def get_num_transfer_tokens(block_length, steps):
    """Spread ``block_length`` tokens as evenly as possible over ``steps`` steps.

    Returns an int64 tensor of length ``steps`` that sums to ``block_length``;
    the first ``block_length % steps`` entries get one extra token.
    """
    per_step, extra = divmod(block_length, steps)
    schedule = torch.full((steps,), per_step, dtype=torch.int64)
    schedule[:extra] += 1
    return schedule
_REMASKING_STRATEGIES = ('sequential', 'low_confidence_static', 'low_confidence_dynamic')


def _select_transfer_index(mask_index, x0_p, num_to_transfer,
                           remasking_strategy, confidence_threshold):
    """Choose which still-masked positions take their sampled token this step.

    Args:
        mask_index: Bool tensor ``[batch, block]``; True where still masked.
        x0_p: Confidence of each sampled token, same shape as ``mask_index``.
        num_to_transfer: Number of tokens scheduled to be committed this step.
        remasking_strategy: One of ``_REMASKING_STRATEGIES``.
        confidence_threshold: Probability threshold for
            ``'low_confidence_dynamic'``.

    Returns:
        Bool tensor shaped like ``mask_index``, True where the sampled token
        is committed. It is always a subset of ``mask_index``.

    Raises:
        ValueError: For an unknown strategy, or if ``'sequential'`` finds a
            row with no masked positions.
    """
    transfer_index = torch.zeros_like(mask_index, dtype=torch.bool)
    if remasking_strategy == 'sequential':
        # Commit the next num_to_transfer positions starting at the first mask.
        for j in range(mask_index.shape[0]):
            if not mask_index[j].any():
                raise ValueError(
                    "No mask tokens found in the current block.")
            first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item()
            transfer_index[j, first_mask_index:first_mask_index + num_to_transfer] = True
    elif remasking_strategy in ('low_confidence_static', 'low_confidence_dynamic'):
        # Only masked positions compete; committed ones get -inf confidence.
        confidence = torch.where(mask_index, x0_p, -torch.inf)
        for j in range(confidence.shape[0]):
            if remasking_strategy == 'low_confidence_dynamic':
                # Commit every confident token if that meets the schedule.
                high_conf_mask = confidence[j] > confidence_threshold
                if high_conf_mask.sum() >= num_to_transfer:
                    transfer_index[j] = high_conf_mask
                    continue
            _, idx = torch.topk(confidence[j], num_to_transfer)
            transfer_index[j, idx] = True
    else:
        raise ValueError(
            f"Unknown remasking strategy: {remasking_strategy}")
    # When fewer masks remain than scheduled (e.g. a block that still holds
    # trailing prompt tokens), topk also returns -inf, i.e. already decoded,
    # positions. Never overwrite those.
    return transfer_index & mask_index


@torch.no_grad()
def block_diffusion_generate(
    model,
    inputs,
    mask_id,
    max_sequence_length=32768,
    gen_length=128,
    block_length=8,
    denoising_steps=8,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    remasking_strategy='low_confidence_dynamic',
    confidence_threshold=0.9,
    stopping_criteria_idx=None
):
    """Generate a continuation block by block with iterative denoising.

    The prompt is padded with ``mask_id`` tokens to a whole number of blocks.
    Full prompt blocks are prefilled into a KV cache once. Each remaining
    block is then denoised over ``denoising_steps`` steps: the model predicts
    every position, tokens are sampled, and a subset of masked positions
    (chosen by ``remasking_strategy``) is committed. A finished block is run
    once more with ``store_kv=True`` so later blocks can attend to it.

    Args:
        model: Model called with ``attention_mask``, ``position_ids``,
            ``past_key_values``, ``use_cache`` and a custom ``store_kv`` flag,
            returning an object with ``.logits``.
        inputs: Mapping with ``'input_ids'`` and optionally ``'pixel_values'``,
            ``'image_sizes'``, ``'pixel_values_videos'``, ``'image_sizes_videos'``.
            Visual inputs are only forwarded during prefill.
        mask_id: Token id used for not-yet-decoded positions.
        max_sequence_length: Cap on prompt + generated length.
        gen_length: Number of tokens to generate (reduced to fit the cap).
        block_length: Tokens per diffusion block.
        denoising_steps: Denoising iterations per block.
        temperature, top_k, top_p: Sampling controls passed to
            ``sample_with_temperature_topk_topp``.
        remasking_strategy: ``'sequential'``, ``'low_confidence_static'`` or
            ``'low_confidence_dynamic'``.
        confidence_threshold: Threshold for ``'low_confidence_dynamic'``.
        stopping_criteria_idx: Optional token ids; generation stops after the
            first block in which any of them appears in the generated region.

    Returns:
        ``input_ids`` unchanged if nothing can be generated; otherwise a
        ``(1, total_length)`` tensor (prompt + generation, rounded up to a
        block multiple). Positions not reached because of an early stop
        still hold ``mask_id``.

    Raises:
        ValueError: If ``remasking_strategy`` is unknown.
    """
    model.eval()
    input_ids = inputs['input_ids']
    prompt_length = input_ids.shape[1]
    past_key_values = DynamicCache()
    gen_length = min(gen_length, max_sequence_length - prompt_length)
    if gen_length <= 0:
        return input_ids
    # Fail fast, before the expensive prefill pass.
    if remasking_strategy not in _REMASKING_STRATEGIES:
        raise ValueError(
            f"Unknown remasking strategy: {remasking_strategy}")

    device = model.device
    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length

    # Block-causal mask: each token sees its own block and all earlier blocks.
    block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device))
    block_diffusion_attention_mask = (
        block_mask.repeat_interleave(block_length, dim=0)
        .repeat_interleave(block_length, dim=1)
        .unsqueeze(0)
    )
    position_ids = torch.arange(total_length, device=device).unsqueeze(0)

    # Canvas of a single sequence: prompt followed by mask tokens.
    x = torch.full((1, total_length), mask_id,
                   dtype=torch.long, device=device)
    x[:, :prompt_length] = input_ids

    # Only blocks made entirely of prompt tokens are prefilled; a trailing
    # partial prompt block is completed by the decode loop.
    prefill_blocks = prompt_length // block_length
    prefill_length = prefill_blocks * block_length

    # Prefill stage
    # NOTE(review): visual inputs are passed only here. If prefill_length == 0
    # (prompt shorter than one block) they never reach the model — confirm
    # whether that case needs handling.
    if prefill_length > 0:
        model(input_ids=x[:, :prefill_length],
              pixel_values=inputs.get('pixel_values', None),
              image_sizes=inputs.get('image_sizes', None),
              pixel_values_videos=inputs.get('pixel_values_videos', None),
              image_sizes_videos=inputs.get('image_sizes_videos', None),
              attention_mask=block_diffusion_attention_mask[:, :prefill_length, :prefill_length],
              position_ids=position_ids[:, :prefill_length],
              past_key_values=past_key_values,
              use_cache=True,
              store_kv=True)

    num_transfer_tokens = get_num_transfer_tokens(block_length, denoising_steps)

    # Decode stage
    for num_block in range(prefill_blocks, num_blocks):
        block_start = num_block * block_length
        block_end = block_start + block_length
        cur_x = x[:, block_start:block_end].clone()
        cur_attn_mask = block_diffusion_attention_mask[:, block_start:block_end, :block_end]
        cur_position_ids = position_ids[:, block_start:block_end]

        # One extra iteration so a fully decoded block is stored in the cache.
        for step in range(denoising_steps + 1):
            mask_index = (cur_x == mask_id)
            if mask_index.sum() == 0:
                # Store kv cache for the finished block.
                model(cur_x,
                      attention_mask=cur_attn_mask,
                      position_ids=cur_position_ids,
                      past_key_values=past_key_values,
                      use_cache=True,
                      store_kv=True)
                break
            # NOTE(review): if the sampler emits mask_id itself, masks can
            # survive all denoising_steps and num_transfer_tokens[step] raises
            # IndexError on the extra iteration.

            # Denoising
            logits = model(cur_x,
                           attention_mask=cur_attn_mask,
                           position_ids=cur_position_ids,
                           past_key_values=past_key_values,
                           use_cache=True,
                           store_kv=False).logits
            x0, x0_p = sample_with_temperature_topk_topp(
                logits,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
            transfer_index = _select_transfer_index(
                mask_index, x0_p, int(num_transfer_tokens[step]),
                remasking_strategy, confidence_threshold)
            cur_x[transfer_index] = x0[transfer_index]

        x[:, block_start:block_end] = cur_x
        generated_tokens = x[0, prompt_length:block_end].tolist()
        if detect_tail_copy_repetition(generated_tokens, min_repeats=20):
            print("Detected self-replicating repetition at the end, halting generation early!")
            break
        if stopping_criteria_idx is not None and any(stop_idx in x[:, prompt_length:] for stop_idx in stopping_criteria_idx):
            break
    return x
if __name__ == "__main__":
    import os

    # Hugging Face Hub id by default; set SDAR_VL_MODEL_DIR to use a local copy.
    model_dir = os.environ.get("SDAR_VL_MODEL_DIR", "JetLM/SDAR-VL-Think-8B")
    model = AutoModelForImageTextToText.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        device_map="cuda:0"
    )
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)

    # Backslashes are escaped so the LaTeX reaches the model intact: a bare
    # "\frac" would otherwise start with a form-feed character ("\f").
    question = (
        "Question: A circle $K$ is inscribed in a quarter circle with radius 6 "
        "as shown in the figure. What is the radius of circle $K$?\n"
        "(A) $\\frac{6-\\sqrt{2}}{2}$\n"
        "(B) $\\frac{3 \\sqrt{2}}{2}$\n"
        "(C) 2.5\n"
        "(D) 3\n"
        "(E) $6(\\sqrt{2}-1)$"
    )
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    image_file = "./assets/173.jpg"
    with Image.open(image_file) as raw_image:
        inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

    # Decoding configuration.
    denoising_steps = 4
    block_length = 4
    gen_length = 2048
    remasking_strategy = 'low_confidence_static'
    print(f"generation args: denoising_steps={denoising_steps}, block_length={block_length}, remasking_strategy={remasking_strategy}")
    results = block_diffusion_generate(
        model,
        inputs=inputs,
        mask_id=151669,
        gen_length=gen_length,
        block_length=block_length,
        denoising_steps=denoising_steps,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        remasking_strategy=remasking_strategy,
        confidence_threshold=0.9,
        stopping_criteria_idx=[151645, 151643]
    )
    output_text = processor.decode(results[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    # Positions not reached after an early stop may still hold the mask token.
    cleaned_text = output_text.replace('<|MASK|>', '')
    print(cleaned_text)