mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-03-07 14:24:44 +00:00
31 lines
844 B
Python
31 lines
844 B
Python
"""Minimal AirLLM inference example.

Loads a 70B Llama-2-family model via AirLLM's layer-by-layer loading and
runs a very short generation on a single prompt, printing the decoded
output.
"""

import torch

from airllm import AirLLMLlama2

# Maximum number of tokens kept per prompt after truncation.
MAX_LENGTH = 128

# Could use a Hugging Face model repo id:
model = AirLLMLlama2("garage-bAInd/Platypus2-70B-instruct")

# ...or use a model's local path:
#model = AirLLMLlama2("/home/ubuntu/.cache/huggingface/hub/models--garage-bAInd--Platypus2-70B-instruct/snapshots/b585e74bcaae02e52665d9ac6d23f4d0dbc81a0f")

input_text = [
    'What is the capital of United States?',
    #'I like',
]

input_tokens = model.tokenizer(input_text,
                               return_tensors="pt",
                               return_attention_mask=False,
                               truncation=True,
                               max_length=MAX_LENGTH,
                               padding=True)

# Fix: the original called `.cuda()` unconditionally, which raises on
# CPU-only machines. Fall back to CPU when no CUDA device is available;
# behavior on GPU machines is unchanged.
device = "cuda" if torch.cuda.is_available() else "cpu"

generation_output = model.generate(
    input_tokens['input_ids'].to(device),
    # NOTE: deliberately tiny so the example finishes quickly;
    # increase max_new_tokens for a real answer.
    max_new_tokens=2,
    use_cache=True,
    return_dict_in_generate=True)

# generate(..., return_dict_in_generate=True) returns an object whose
# `.sequences` holds the token ids; decode the first (and only) sequence.
output = model.tokenizer.decode(generation_output.sequences[0])

print(output)