import math

import torch
import torch.nn.functional as F


def compute_boundary_mask(boundary_logprobs: torch.Tensor, boundary_threshold: str) -> torch.Tensor:
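    """Turn per-position boundary log-probabilities into a boolean mask.

    `boundary_threshold` selects the strategy:
      - "sample:0"        deterministic: True where probability > 0.5
      - "sample:1"        Bernoulli sample at the predicted probabilities
      - "topk:K"          roughly the K highest-scoring positions per row
      - "topk_percent:P"  the top fraction P (0 <= P <= 1) per row
    """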
    if boundary_threshold.startswith("sample:"):
        _, temperature = boundary_threshold.split(":")
        temperature = float(temperature)

        if temperature == 0:
            return (boundary_logprobs > math.log(0.5))
        elif temperature == 1:
            return torch.bernoulli(torch.exp(boundary_logprobs)).to(torch.bool)
        else:
            raise NotImplementedError("Temperatures outside {0,1} are not implemented yet.")
    elif boundary_threshold.startswith("topk:"):
        _, topk = boundary_threshold.split(":")
        topk = int(topk)
        thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - (topk / boundary_logprobs.shape[1]))
        return (boundary_logprobs >= thresholds.unsqueeze(-1))
    elif boundary_threshold.startswith("topk_percent:"):
        _, topk_percent = boundary_threshold.split(":")
        topk_percent = float(topk_percent)
        assert 0 <= topk_percent <= 1
        thresholds = torch.quantile(boundary_logprobs, dim=1, q=1 - topk_percent)
        return (boundary_logprobs >= thresholds.unsqueeze(-1))
    else:
        raise ValueError(f"Unknown boundary threshold: {boundary_threshold}")
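

# A minimal usage sketch (hypothetical tensor): given boundary
# log-probabilities of shape (batch, seq_len), each mode yields a boolean
# mask of the same shape, e.g.:
#
#   logprobs = torch.rand(2, 16).log()
#   compute_boundary_mask(logprobs, "sample:0")           # deterministic, p > 0.5
#   compute_boundary_mask(logprobs, "sample:1")           # Bernoulli sample
#   compute_boundary_mask(logprobs, "topk:4")             # ~4 True entries per row
#   compute_boundary_mask(logprobs, "topk_percent:0.25")  # top quarter per row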


def _pad(tensors: list[torch.Tensor], multiple_of: int, direction: str, value):
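    """Pad 1-D tensors to a common length and stack into (batch, max_len).

    The common length is the longest input, rounded up to a multiple of
    `multiple_of`; `value` fills the padding on the `direction` side.
    """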
    max_len = max(t.size(0) for t in tensors)
    if multiple_of > 1:
        # Round up max_len to the nearest multiple_of
        max_len = ((max_len + multiple_of - 1) // multiple_of) * multiple_of
    padded = []
    for t in tensors:
        if direction == "left":
            pad_shape = (max_len - t.size(0), 0)
        elif direction == "right":
            pad_shape = (0, max_len - t.size(0))
        else:
            raise ValueError(f"Unknown direction: {direction}. Must be 'left' or 'right'.")
        padded.append(F.pad(t, pad_shape, value=value))
    return torch.stack(padded, dim=0)


def pad_right(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    return _pad(tensors, multiple_of, direction="right", value=value)


def pad_left(
    tensors: list[torch.Tensor],
    multiple_of: int = 128,
    value=0,
):
    return _pad(tensors, multiple_of, direction="left", value=value)
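

# Usage sketch: pad a ragged batch of 1-D tensors to a shared length and
# stack, e.g.:
#
#   seqs = [torch.arange(3), torch.arange(5)]
#   pad_right(seqs, multiple_of=4, value=-1)  # shape (2, 8), padding on the right
#   pad_left(seqs, multiple_of=4, value=-1)   # shape (2, 8), padding on the left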

class MaskState:
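    """Caches a boolean mask together with host-side any/all flags (a single
    device-to-host sync at construction) so that the all-True / all-False
    cases can skip boolean indexing, which would otherwise force a device
    sync via `nonzero`.
    """
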
    def __init__(self, mask):
        self.cpu_mask = mask.cpu()

        self.mask = mask
        self.inv_mask = ~mask
        self._all = self.cpu_mask.all().item()
        self._any = self.cpu_mask.any().item()

    def any(self):
        return self._any

    def all(self):
        return self._all

    def selective_get(self, x, inv=False):
        # Branch on the cached host-side flags to avoid a sync through
        # nonzero on index in the degenerate cases.
        if inv:
            if self.all():
                return x[[]]  # empty selection
            elif not self.any():
                return x
            else:
                return x[self.inv_mask]
        else:
            if self.all():
                return x
            elif not self.any():
                return x[[]]  # empty selection
            else:
                return x[self.mask]

    def selective_put(self, x, out, inv=False):
        # try to avoid sync through nonzero on index
        if inv:
            if self.all():
                return
            elif not self.any():
                out.copy_(x)
            else:
                out[self.inv_mask] = x
        else:
            if self.all():
                out.copy_(x)
            elif not self.any():
                return
            else:
                out[self.mask] = x

    def selective_add(self, x, out, inv=False):
        # try to avoid sync through nonzero on index
        if inv:
            if self.all():
                return
            elif not self.any():
                out.add_(x)
            else:
                out[self.inv_mask] += x
        else:
            if self.all():
                out.add_(x)
            elif not self.any():
                return
            else:
                out[self.mask] += x
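

# Usage sketch (hypothetical tensors): gather the active rows of a batch,
# process them, and scatter the results back without per-call syncs, e.g.:
#
#   state = MaskState(torch.tensor([True, False, True]))
#   x = torch.randn(3, 8)
#   active = state.selective_get(x)          # rows where the mask is True
#   state.selective_put(active * 2, out=x)   # write processed rows back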