mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-08 06:43:15 +00:00
fix training
This commit is contained in:
@@ -317,7 +317,7 @@ class SampleGenerateCallback(transformers.TrainerCallback):
|
||||
for sample_input in sample_inputs:
|
||||
tokenizer = kwargs['tokenizer']
|
||||
inputs = sample_input['prompt'] + sample_input['prompt_postfix']
|
||||
logger.info(f"sample input: {inputs}")
|
||||
logger.info(f"sample input: {inputs[:60]}")
|
||||
model = kwargs['model']
|
||||
input_ids = tokenizer(inputs, return_tensors="pt")['input_ids']
|
||||
input_ids = input_ids.to('cuda')
|
||||
@@ -331,7 +331,7 @@ class SampleGenerateCallback(transformers.TrainerCallback):
|
||||
|
||||
|
||||
#print(generation_output)
|
||||
logger.info(f"sample output: {tokenizer.decode(generation_output[0])}")
|
||||
logger.info(f"sample output: {tokenizer.decode(generation_output[0])[-60:]}")
|
||||
|
||||
else:
|
||||
logger.info(f"model not found in kwargs, skipping")
|
||||
@@ -380,7 +380,7 @@ def get_accelerate_model(args, checkpoint_dir):
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
config = AutoConfig.from_pretrained("togethercomputer/LLaMA-2-7B-32K")
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
config.rope_scaling['factor'] = 32.0
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -907,5 +907,5 @@ if __name__ == "__main__":
|
||||
try:
|
||||
train()
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.info(f"oom: {e}", stack_info=True)
|
||||
logger.info(f"oom: {e}", exc_info=True)
|
||||
print_tensors('before oom')
|
||||
|
||||
Reference in New Issue
Block a user