import math

import torch
import torch.nn.functional as F


def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
    """Turn per-position boundary log-probabilities into a boolean boundary mask.

    The strategy is encoded in `boundary_threshold`:
      - "sample:<temperature>": threshold at probability 0.5 (temperature 0) or
        sample each position from a Bernoulli with its predicted probability
        (temperature 1).
      - "topk:<k>": keep roughly the k highest-scoring positions per sequence.
      - "topk_percent:<fraction>": keep roughly the top fraction of positions per sequence.
    """
    if boundary_threshold.startswith("sample:"):
        _, temperature = boundary_threshold.split(":")
        temperature = float(temperature)
        if temperature == 0:
            # Deterministic: a position is a boundary iff its probability exceeds 0.5.
            return boundary_logprobs > math.log(0.5)
        elif temperature == 1:
            # Stochastic: sample each boundary from its predicted probability.
            return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
        else:
            raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
    elif boundary_threshold.startswith("topk:"):
        _, topk = boundary_threshold.split(":")
        topk = int(topk)
        # Per-sequence quantile threshold chosen so that roughly `topk` positions survive.
        thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - (topk / boundary_logprobs.shape[1]))
        return boundary_logprobs >= thresholds.unsqueeze(-1)
    elif boundary_threshold.startswith("topk_percent:"):
        _, topk_percent = boundary_threshold.split(":")
        topk_percent = float(topk_percent)
        assert 0 <= topk_percent <= 1
        thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - topk_percent)
        return boundary_logprobs >= thresholds.unsqueeze(-1)
    else:
        raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")


def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
    """Pad each tensor to a common length and stack them along a new batch dimension."""
    max_len = max(t.size(0) for t in tensors)
    if multiple_of > 1:
        # Round the target length up to the next multiple of `multiple_of`.
        max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of
    padded = []
    for t in tensors:
        if direction == "left":
            pad_shape = (max_len - t.size(0), 0)
        elif direction == "right":
            pad_shape = (0, max_len - t.size(0))
        else:
            raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
        padded.append(F.pad(t, pad_shape, value=value))
    return torch.stack(padded, dim=0)


def pad_right(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    return _pad(tensors, multiple_of, direction="right", value=value)


def pad_left(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    return _pad(tensors, multiple_of, direction="left", value=value)


class MaskState:
    """Wraps a boolean selection mask, caching its inverse and Python-level
    any()/all() flags so downstream code can short-circuit the all-selected and
    none-selected cases without extra device synchronization."""

    def __init__(self, mask):
        # Keep a CPU copy so any()/all() can be read as plain Python values later.
        self.cpu_mask = mask.cpu()
        self.mask = mask
        self.inv_mask = ~mask
        self._all = self.cpu_mask.all().item()
        self._any = self.cpu_mask.any().item()

    def any(self):
        return self._any

    def all(self):
        return self._all

    def selective_get(self, x, inv=False):
        # Return the rows of `x` selected by the mask (or by its inverse when inv=True),
        # short-circuiting the trivial all-True / all-False cases.
        if inv:
            if self.all():
                return x[[]]
            elif not self.any():
                return x
            else:
                return x[self.inv_mask]
        else:
            if self.all():
                return x
            elif not self.any():
                return x[[]]
            else:
                return x[self.mask]

    def selective_put(self, x, out, inv=False):
        # Write `x` into the rows of `out` selected by the mask (or its inverse).
        if inv:
            if self.all():
                return
            elif not self.any():
                out.copy_(x)
            else:
                out[self.inv_mask] = x
        else:
            if self.all():
                out.copy_(x)
            elif not self.any():
                return
            else:
                out[self.mask] = x

    def selective_add(self, x, out, inv=False):
        # Accumulate `x` into the rows of `out` selected by the mask (or its inverse).
        if inv:
            if self.all():
                return
            elif not self.any():
                out.add_(x)
            else:
                out[self.inv_mask] += x
        else:
            if self.all():
                out.add_(x)
            elif not self.any():
                return
            else:
                out[self.mask] += x
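

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API; shapes and values are
    # illustrative assumptions): route only the masked rows of a batch through
    # one path and write the results back in place.
    x = torch.randn(4, 3)
    out = torch.zeros_like(x)
    state = MaskState(torch.tensor([True, False, True, False]))

    selected = state.selective_get(x)              # rows of x where the mask is True
    state.selective_put(selected * 2.0, out)       # write back only into those rows
    state.selective_add(selected, out)             # accumulate into the same rows

    rest = state.selective_get(x, inv=True)        # the complementary rows
    state.selective_put(rest, out, inv=True)       # fill in the unmasked rows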