Translation Local Agreement + Cache optimization v0. Not connected yet

Quentin Fuxa
2025-10-28 00:16:52 +01:00
parent 61edb70fff
commit 939a7ebf8b


@@ -0,0 +1,159 @@
import torch
from typing import Tuple, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput

# Load NLLB-200 (distilled, 600M) and put it on GPU when available.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
device = model.device

# NLLB selects the target language via a forced BOS token (French here).
bos_token_id = tokenizer.convert_tokens_to_ids("fra_Latn")
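# Optional sanity check (an added assumption, not part of the original flow):
# the language code must resolve to a real vocabulary entry rather than <unk>.
assert bos_token_id != tokenizer.unk_token_id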
def compute_common_prefix_tokens(
    prev_tokens: Optional[torch.Tensor],
    new_tokens: torch.Tensor,
    tokenizer: AutoTokenizer,
    sep: str = " ",
) -> torch.Tensor:
    """Return the longest prefix of new_tokens whose decoded text agrees,
    word by word, with the decoded prev_tokens (local agreement)."""
    if prev_tokens is None or len(prev_tokens) == 0:
        return new_tokens
    prev_text = tokenizer.decode(prev_tokens, skip_special_tokens=True)
    new_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    if not prev_text or not new_text:
        return new_tokens
    # Compare word by word until the first disagreement.
    prev_words = prev_text.split(sep)
    new_words = new_text.split(sep)
    common_word_count = 0
    for i in range(min(len(prev_words), len(new_words))):
        if prev_words[i] == new_words[i]:
            common_word_count += 1
        else:
            break
    if common_word_count == 0:
        # No agreement at all: keep a single token so the caller still has a prefix.
        if len(new_tokens) > 0:
            return new_tokens[:1]
        return new_tokens
    if common_word_count == len(prev_words) and len(prev_words) == len(new_words):
        # Full agreement: the whole new hypothesis is stable.
        return new_tokens
    # Map the agreed words back to a token boundary in new_tokens.
    common_prefix_text = sep.join(new_words[:common_word_count])
    for token_idx in range(1, len(new_tokens) + 1):
        decoded = tokenizer.decode(new_tokens[:token_idx], skip_special_tokens=True)
        if decoded == common_prefix_text:
            return new_tokens[:token_idx]
        if len(decoded) > len(common_prefix_text):
            # Overshot the word boundary: back off one token.
            return new_tokens[:max(1, token_idx - 1)]
    return new_tokens
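# Illustrative sketch of the behavior (assumed strings, not project test data):
#   prev = tokenizer("Avez-vous remarqué", return_tensors="pt").input_ids[0]
#   new  = tokenizer("Avez-vous remarqué à quel point", return_tensors="pt").input_ids[0]
#   compute_common_prefix_tokens(prev, new, tokenizer)
# returns the tokens of `new` that decode back to "Avez-vous remarqué",
# i.e. the locally agreed prefix of the two successive hypotheses.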

def manual_generate(
    encoder_outputs: BaseModelOutput,
    attention_mask: torch.Tensor,
    forced_bos_token_id: int,
    max_length: int = 50,
    eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor, Optional[Tuple]]:
    """Baseline path: full generation with model.generate(), reusing
    precomputed encoder outputs."""
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    with torch.no_grad():
        generated_tokens = model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
            eos_token_id=eos_token_id,
        )
    # No decoder cache is exposed on this baseline path.
    return generated_tokens, None

def continue_generation_with_cache(
    encoder_hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    prefix_tokens: torch.Tensor,
    max_new_tokens: int = 50,
    eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Greedy decoding that scores a known prefix once, then continues
    token by token while reusing the decoder KV cache."""
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    prefix_tokens = prefix_tokens.to(device)
    if prefix_tokens.dim() == 1:
        prefix_tokens = prefix_tokens.unsqueeze(0)
    with torch.no_grad():
        # Single forward pass over the whole prefix to build the KV cache.
        decoder_out = model.model.decoder(
            input_ids=prefix_tokens,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        prefix_logits = model.lm_head(decoder_out.last_hidden_state)
        past_key_values = decoder_out.past_key_values
        next_token_id = torch.argmax(prefix_logits[:, -1, :], dim=-1).unsqueeze(-1)
        if next_token_id.item() == eos_token_id:
            return prefix_tokens
        generated_tokens = torch.cat([prefix_tokens, next_token_id], dim=-1)
        # Note: max_new_tokens caps the total output length, prefix included.
        tokens_to_generate = max_new_tokens - generated_tokens.shape[1]
        for _ in range(tokens_to_generate):
            # Feed only the newest token; earlier positions come from the cache.
            decoder_out = model.model.decoder(
                input_ids=next_token_id,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            logits = model.lm_head(decoder_out.last_hidden_state)
            next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
            past_key_values = decoder_out.past_key_values
            if next_token_id.item() == eos_token_id:
                break
            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=-1)
    return generated_tokens
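# Design note: the cache path above scores the known prefix in one batched
# decoder pass, then pays a single decoder step per new token over cached
# keys/values, instead of re-decoding the whole target with generate() each
# time the source grows.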

if __name__ == "__main__":
    # Simulate a growing streaming input: first a short chunk, then its extension.
    src_text_0 = "Have you noticed how accurate"
    src_text_1 = " LLMs are now that GPUs have become more powerful"

    # Translate the first chunk from scratch.
    inputs_0 = tokenizer(src_text_0, return_tensors="pt").to(device)
    encoder_outputs_0 = model.get_encoder()(**inputs_0)
    translation_tokens_0, _ = manual_generate(
        encoder_outputs=encoder_outputs_0,
        attention_mask=inputs_0['attention_mask'],
        forced_bos_token_id=bos_token_id,
    )
    print(f"translation 0: {tokenizer.decode(translation_tokens_0[0], skip_special_tokens=True)}")

    # Re-encode the full source, then continue decoding from the previous
    # translation as a prefix, reusing the decoder KV cache.
    src_text_full = src_text_0 + src_text_1
    inputs_1 = tokenizer(src_text_full, return_tensors="pt").to(device)
    encoder_outputs_1 = model.get_encoder()(**inputs_1)
    prefix_tokens = translation_tokens_0[0].clone()
    if prefix_tokens[-1].item() == tokenizer.eos_token_id:
        # Drop the trailing </s> so the decoder can keep generating.
        prefix_tokens = prefix_tokens[:-1]
    translation_tokens_1 = continue_generation_with_cache(
        encoder_hidden_states=encoder_outputs_1.last_hidden_state,
        attention_mask=inputs_1['attention_mask'],
        prefix_tokens=prefix_tokens,
        max_new_tokens=200,
    )
    print(f"translation 1: {tokenizer.decode(translation_tokens_1[0], skip_special_tokens=True)}")