fix training: truncate sample-generation logging, load config from args.model_name_or_path, log OOM with traceback

Yu Li
2023-09-15 21:59:24 -05:00
parent 96262af7cc
commit 8f9b024f9e


@@ -317,7 +317,7 @@ class SampleGenerateCallback(transformers.TrainerCallback):
             for sample_input in sample_inputs:
                 tokenizer = kwargs['tokenizer']
                 inputs = sample_input['prompt'] + sample_input['prompt_postfix']
-                logger.info(f"sample input: {inputs}")
+                logger.info(f"sample input: {inputs[:60]}")
                 model = kwargs['model']
                 input_ids = tokenizer(inputs, return_tensors="pt")['input_ids']
                 input_ids = input_ids.to('cuda')
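
This hunk only trims what gets echoed to the training log: with 32K-token prompts, logging the full input floods the log. A minimal sketch of the pattern using only the standard logging module (the log_truncated helper is hypothetical, not part of this repo):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_truncated(label, text, limit=60):
    # Hypothetical helper mirroring the inputs[:60] slice above:
    # log only the first `limit` characters of a long string.
    suffix = "..." if len(text) > limit else ""
    logger.info(f"{label}: {text[:limit]}{suffix}")

log_truncated("sample input", "Summarize the following document: " + "lorem ipsum " * 500)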
@@ -331,7 +331,7 @@ class SampleGenerateCallback(transformers.TrainerCallback):
                 #print(generation_output)
-                logger.info(f"sample output: {tokenizer.decode(generation_output[0])}")
+                logger.info(f"sample output: {tokenizer.decode(generation_output[0])[-60:]}")
         else:
             logger.info(f"model not found in kwargs, skipping")
@@ -380,7 +380,7 @@ def get_accelerate_model(args, checkpoint_dir):
     from transformers import AutoConfig
-    config = AutoConfig.from_pretrained("togethercomputer/LLaMA-2-7B-32K")
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
     config.rope_scaling['factor'] = 32.0
     model = AutoModelForCausalLM.from_pretrained(
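
Here the hardcoded checkpoint name is replaced by args.model_name_or_path, so the rope_scaling override applies to whatever model is actually being trained. A sketch of the same override in isolation, assuming a LLaMA-style config that already carries a rope_scaling dict (the fallback branch is an assumption for checkpoints that ship without one):

from transformers import AutoConfig, AutoModelForCausalLM

model_name_or_path = "togethercomputer/LLaMA-2-7B-32K"  # the value previously hardcoded

config = AutoConfig.from_pretrained(model_name_or_path)
if config.rope_scaling is None:
    config.rope_scaling = {"type": "linear", "factor": 32.0}  # assumed default
else:
    config.rope_scaling["factor"] = 32.0  # stretch RoPE positions 32x for long context

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)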
@@ -907,5 +907,5 @@ if __name__ == "__main__":
     try:
         train()
     except torch.cuda.OutOfMemoryError as e:
-        logger.info(f"oom: {e}", stack_info=True)
+        logger.info(f"oom: {e}", exc_info=True)
         print_tensors('before oom')
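
The last hunk swaps stack_info=True for exc_info=True: inside an except block, exc_info attaches the active exception's traceback to the log record, while stack_info only records the stack of the logging call itself and omits the exception details. A self-contained sketch (the simulated error stands in for a real CUDA OOM):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("CUDA out of memory (simulated)")
except RuntimeError as e:
    logger.info(f"oom: {e}", exc_info=True)  # full traceback lands in the log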