Bolmo-7B / utils_bolmo.py
benjamin's picture
Upload folder using huggingface_hub
8b28006 verified
import math
import torch
import torch.nn.functional as F
def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
    """Turn boundary log-probabilities into a boolean boundary mask.

    ``boundary_threshold`` selects the strategy and carries its argument
    after a colon:

    * ``"sample:0"`` - deterministic: True where the probability exceeds 0.5
      (i.e. ``logprob > log(0.5)``).
    * ``"sample:1"`` - draw each entry from a Bernoulli distribution with
      probability ``exp(logprob)``.
    * ``"topk:K"`` - per row, keep entries at or above the
      ``1 - K / boundary_logprobs.shape[1]`` quantile taken along dim 1.
    * ``"topk_percent:P"`` - same, with quantile ``1 - P`` for ``P`` in [0, 1].

    The quantile strategies reduce along dim 1 and broadcast the threshold
    over the last dim, so a 2-D ``(batch, seq_len)`` layout is assumed.

    Returns:
        A boolean tensor with the same shape as ``boundary_logprobs``.

    Raises:
        NotImplementedError: for ``sample:`` temperatures other than 0 or 1.
        ValueError: if the strategy prefix is unknown or its argument cannot
            be parsed.
    """
    if boundary_threshold.startswith("sample:"):
        _, temperature = boundary_threshold.split(":")
        temperature = float(temperature)

        if temperature == 0:
            return boundary_logprobs > math.log(0.5)
        elif temperature == 1:
            return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
        else:
            raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
    elif boundary_threshold.startswith("topk:"):
        _, topk = boundary_threshold.split(":")
        topk = int(topk)
        return _quantile_boundary_mask(
            boundary_logprobs,
            1 - (topk / boundary_logprobs.shape[1]),
        )
    elif boundary_threshold.startswith("topk_percent:"):
        _, topk_percent = boundary_threshold.split(":")
        topk_percent = float(topk_percent)
        # NOTE: kept as an assert so the raised exception type is unchanged;
        # be aware it is stripped when Python runs with -O.
        assert 0 <= topk_percent <= 1
        return _quantile_boundary_mask(boundary_logprobs, 1 - topk_percent)
    else:
        raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")


def _quantile_boundary_mask(boundary_logprobs: torch.Tensor, q: float) -> torch.Tensor:
    """Mark entries that are >= the per-row ``q``-quantile taken along dim 1."""
    values = boundary_logprobs
    # torch.quantile rejects inputs that are not float32/float64 (e.g. bf16 or
    # fp16 model outputs), so upcast only for the threshold computation.
    if values.dtype not in (torch.float32, torch.float64):
        values = values.float()
    thresholds = torch.quantile(values, dim=1, q=q)
    # Compare against the original tensor; type promotion makes the
    # comparison exact because every half-precision value fits in float32.
    return boundary_logprobs >= thresholds.unsqueeze(-1)
def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
    """Pad tensors along dim 0 to a common length and stack them.

    The target length is the largest ``size(0)`` among ``tensors``, rounded up
    to the next multiple of ``multiple_of`` when ``multiple_of > 1``.

    Args:
        tensors: Tensors to pad; they must agree on every dim except dim 0.
        multiple_of: Round the common length up to a multiple of this value
            (values <= 1 disable rounding).
        direction: ``"left"`` to insert padding before the data on dim 0,
            ``"right"`` to append it after.
        value: Fill value used for the padded positions.

    Returns:
        A tensor of shape ``(len(tensors), padded_len, *trailing_dims)``.

    Raises:
        ValueError: if ``direction`` is neither ``"left"`` nor ``"right"``,
            or (from ``max``) if ``tensors`` is empty.
    """
    max_len = max(t.size(0) for t in tensors)
    if multiple_of > 1:
        # Round up max_len to the nearest multiple_of
        max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of

    if direction not in ("left", "right"):
        raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")

    padded = []
    for t in tensors:
        missing = max_len - t.size(0)
        dim0_pad = (missing, 0) if direction == "left" else (0, missing)
        # F.pad reads its pad tuple starting from the LAST dim. The length was
        # measured on dim 0, so emit a (0, 0) pair for every trailing dim to
        # make the padding land on dim 0. For 1-D tensors this is a no-op.
        pad_shape = (0, 0) * (t.dim() - 1) + dim0_pad
        padded.append(F.pad(t, pad_shape, value=value))
    return torch.stack(padded, dim=0)
def pad_right(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    """Stack ``tensors`` after padding each one at its end.

    Thin wrapper that forwards to ``_pad`` with ``direction="right"``.

    Args:
        tensors: Tensors to pad and stack.
        multiple_of: Forwarded to ``_pad``; defaults to 128.
        value: Fill value for the padded positions; defaults to 0.

    Returns:
        Whatever ``_pad`` returns for right-side padding.
    """
    side = "right"
    return _pad(tensors, multiple_of, side, value)
def pad_left(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    """Stack ``tensors`` after padding each one at its start.

    Thin wrapper that forwards to ``_pad`` with ``direction="left"``.

    Args:
        tensors: Tensors to pad and stack.
        multiple_of: Forwarded to ``_pad``; defaults to 128.
        value: Fill value for the padded positions; defaults to 0.

    Returns:
        Whatever ``_pad`` returns for left-side padding.
    """
    side = "left"
    return _pad(tensors, multiple_of, side, value)
class MaskState:
    """A boolean mask bundled with its inverse and cached any()/all() results.

    ``any``/``all`` are evaluated once, on a CPU copy of the mask, when the
    object is built. The ``selective_*`` methods use those cached answers to
    skip boolean-mask indexing whenever the mask selects everything or
    nothing (the intent, per the original author's note, is to avoid the
    sync that indexing incurs through ``nonzero``).

    Every ``selective_*`` method takes ``inv``; when true, the operation
    targets the positions where the mask is False instead.
    """

    def __init__(self, mask):
        self.mask = mask
        self.inv_mask = ~mask
        self.cpu_mask = mask.cpu()
        # Cache both reductions as Python scalars so later queries are free.
        self._all = self.cpu_mask.all().item()
        self._any = self.cpu_mask.any().item()

    def any(self):
        """Return the cached result of ``mask.any()``."""
        return self._any

    def all(self):
        """Return the cached result of ``mask.all()``."""
        return self._all

    def _coverage(self, inv):
        """Return ``(selects_everything, selects_nothing)`` for this polarity."""
        if self.all():
            return (not inv, inv)
        if not self.any():
            return (inv, not inv)
        return (False, False)

    def selective_get(self, x, inv=False):
        """Return the rows of ``x`` selected by the mask (or by its inverse).

        When everything is selected ``x`` itself is returned (no copy); when
        nothing is selected the result is ``x[[]]``.
        """
        everything, nothing = self._coverage(inv)
        if everything:
            return x
        if nothing:
            return x[[]]
        index = self.inv_mask if inv else self.mask
        return x[index]

    def selective_put(self, x, out, inv=False):
        """Write ``x`` into the selected rows of ``out`` in place."""
        everything, nothing = self._coverage(inv)
        if nothing:
            return
        if everything:
            out.copy_(x)
            return
        index = self.inv_mask if inv else self.mask
        out[index] = x

    def selective_add(self, x, out, inv=False):
        """Add ``x`` onto the selected rows of ``out`` in place."""
        everything, nothing = self._coverage(inv)
        if nothing:
            return
        if everything:
            out.add_(x)
            return
        index = self.inv_mask if inv else self.mask
        out[index] += x