# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from typing import List
import torch
from llama.tokenizer import Tokenizer
from llama.model import Transformer


class LLaMA:
def __init__(self, model: Transformer, tokenizer: Tokenizer):
self.model = model
self.tokenizer = tokenizer

    def generate(
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
) -> List[str]:
bsz = len(prompts)
params = self.model.params
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != self.tokenizer.pad_id
start_pos = min_prompt_size
prev_pos = 0
for cur_pos in range(start_pos, total_len):
            # With the KV cache, only tokens produced since the previous step are
            # fed to the model; prev_pos tells the cache where to resume.
            input_ids = tokens[:, prev_pos:cur_pos]
            logits = self.model.forward(input_ids, prev_pos)
            if temperature > 0:
                # Filter the logits with nucleus (top-p) sampling; tail-free and
                # typical sampling are invoked with 1.0, which disables them.
                next_token_scores = sample_top_p_actual(input_ids, logits, top_p)
                next_token_scores = sample_tail_free(input_ids, next_token_scores, 1.0)
                next_token_scores = sample_typical(input_ids, next_token_scores, 1.0)
                next_token_scores = sample_temperature(input_ids, next_token_scores, temperature)
                # Penalize repetitions over the full context generated so far;
                # input_ids only holds the current KV-cache window, which shrinks to
                # a single token after the first step (penalty_range=1024,
                # penalty_slope=0.7, penalty=1.1).
                next_token_scores = sample_advanced_repetition_penalty(tokens[:, :cur_pos], next_token_scores, 1024, 0.7, 1.1)
                next_token_scores = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(next_token_scores, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
prev_pos = cur_pos
decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[: len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass
decoded.append(self.tokenizer.decode(t))
return decoded
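

# A minimal usage sketch (hypothetical paths; the real repo wires the model and
# tokenizer up in example.py, including fairscale model-parallel initialization):
#
#   tokenizer = Tokenizer(model_path="tokenizer.model")   # assumed local path
#   model = Transformer(model_args)                       # weights from a checkpoint
#   generator = LLaMA(model, tokenizer)
#   print(generator.generate(["The capital of France is"], max_gen_len=32))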


# The samplers below are adapted from KoboldAI and Hugging Face transformers,
# so this part of the file is presumably AGPL.
def sample_temperature(input_ids, scores, temperature):
    # Scale the logits by the temperature: values below 1 sharpen the
    # distribution, values above 1 flatten it.
    return scores / temperature


def sample_typical(input_ids, scores, typical, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # typical >= 1.0 disables typical sampling (the callers above rely on this).
    if typical >= 1.0:
        return scores
    probs = scores.softmax(dim=-1)
    log_probs = probs.log()
    # Rank tokens by how far their surprisal -log p deviates from the
    # distribution's entropy, then keep them in that order until the
    # `typical` probability mass is covered.
    neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
    entropy_deviation = (neg_entropy - log_probs).abs()
    _, sorted_indices = torch.sort(entropy_deviation)
    sorted_probs = probs.gather(-1, sorted_indices)
    sorted_indices_to_remove = sorted_probs.cumsum(dim=-1) >= typical
    # Shift the mask right so the token that crosses the threshold is kept.
    sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
    min_tokens_to_keep = max(min_tokens_to_keep, 1)
    # Keep at least min_tokens_to_keep tokens.
    sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
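
# Worked example for sample_typical (illustrative numbers, not from the source):
# for probabilities [0.5, 0.3, 0.15, 0.05] the entropy is about 1.14 nats, so the
# surprisal deviations are roughly [0.45, 0.06, 0.76, 1.85] and sorting by them
# orders the tokens 0.3, 0.5, 0.15, 0.05. With typical=0.9 the cumulative mass
# [0.3, 0.8, 0.95, 1.0] crosses the threshold at 0.15, and since the roll keeps
# the crossing token, only the 0.05 token is masked.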


def sample_top_p_actual(input_ids, scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # With the ascending sort, tokens whose cumulative probability is at most
    # 1 - top_p are exactly the tail outside the top-p nucleus; mask them.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep tokens.
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
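
# Worked example for sample_top_p_actual (illustrative numbers): with
# probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the ascending cumulative
# sums are [0.05, 0.20, 0.50, 1.00]; only entries <= 1 - 0.9 = 0.1 are masked,
# so the 0.05 tail token is dropped and the {0.5, 0.3, 0.15} nucleus survives.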


def sample_advanced_repetition_penalty(input_ids, scores, penalty_range, penalty_slope, penalty):
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)
    if penalty != 1.0:
        if penalty_range > 0:
            if clipped_penalty_range < input_ids.shape[1]:
                input_ids = input_ids[..., -clipped_penalty_range:]
            if penalty_slope != 0:
                # Build a ramp so the penalty fades from 1 at the oldest token in
                # range up to the full value at the most recent token.
                _penalty = (torch.arange(penalty_range, dtype=scores.dtype, device=scores.device) / (penalty_range - 1)) * 2.0 - 1
                _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))
                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
                penalty = _penalty[..., -clipped_penalty_range:]
        score = torch.gather(scores, 1, input_ids)
        # Dividing positive logits and multiplying negative ones both lower the
        # token's probability; note that scatter_ mutates `scores` in place.
        score = torch.where(score <= 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)
    return scores
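
# Worked example for the repetition penalty (illustrative numbers): with
# penalty=1.1 and penalty_slope=0, a previously seen token with logit 2.0
# becomes 2.0 / 1.1 ~= 1.82 and one with logit -2.0 becomes -2.2; both moves
# make the token less likely to be sampled again.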


def sample_top_a(input_ids, scores, top_a, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # top_a <= 0.0 disables top-a filtering.
    if top_a <= 0.0:
        return scores
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    # Remove tokens whose probability is below top_a * (max probability)^2.
    probs_max = probs[..., 0, None]
    sorted_indices_to_remove = probs < probs_max * probs_max * top_a
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
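
# Worked example for sample_top_a (illustrative numbers): if the most likely
# token has probability 0.5 and top_a=0.2, the cutoff is 0.2 * 0.5**2 = 0.05 and
# everything under 5% probability is masked; with a flatter maximum of 0.2 the
# cutoff drops to 0.008, so the filter backs off automatically.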


def sample_tail_free(input_ids, scores, tfs, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # tfs >= 1.0 disables tail-free sampling (the callers above rely on this).
    if tfs >= 1.0:
        return scores
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
probs = sorted_logits.softmax(dim=-1)
    # Locate the "tail" via the second derivative of the sorted probability
    # curve, normalized into a CDF.
    d2 = probs.diff().diff().abs()
    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
    normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
    # Remove tokens whose second-derivative CDF exceeds the threshold.
    sorted_indices_to_remove = normalized_d2_cdf > tfs
    # The double diff shortens the mask by two entries, so pad it back to full
    # length: always keep the first sorted token and always drop the last, as in
    # the original implementation of the algorithm.
sorted_indices_to_remove = torch.cat(
(
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
sorted_indices_to_remove,
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
),
dim=-1,
)
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
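

if __name__ == "__main__":
    # Smoke-test sketch for the sampler chain on random logits; runs on CPU with
    # no checkpoint, and the token ids are arbitrary rather than real vocabulary.
    torch.manual_seed(0)
    vocab_size = 32
    history = torch.randint(0, vocab_size, (1, 10))  # pretend generation history
    logits = torch.randn(1, vocab_size)
    scores = sample_top_p_actual(history, logits, 0.95)
    scores = sample_tail_free(history, scores, 0.95)
    scores = sample_typical(history, scores, 0.95)
    scores = sample_top_a(history, scores, 0.2)
    scores = sample_temperature(history, scores, 0.8)
    scores = sample_advanced_repetition_penalty(history, scores, 1024, 0.7, 1.1)
    probs = torch.nn.functional.softmax(scores, dim=-1)
    print("sampled token id:", torch.multinomial(probs, num_samples=1).item())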