| from __future__ import annotations |
|
|
| import csv |
| import json |
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import NamedTuple |
|
|
| import numpy as np |
| import torch |
| import spacy |
| from marisa_trie import Trie |
| from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase |
|
|
# Sentinel KB-id string. It is not referenced in the code visible here;
# presumably it marks "no entity" for callers elsewhere — TODO confirm.
NONE_ID: str = "<None>"
|
|
|
|
@dataclass
class Mention:
    """A dictionary match inside a text, optionally carrying link statistics.

    ``start``/``end`` are character offsets of the match (``end`` is exclusive).
    The three count fields may be ``None`` when no statistics are available, in
    which case the derived probabilities are ``None`` as well.
    """

    kb_id: str | None  # id of the candidate entity in the knowledge base
    text: str  # matched surface string
    start: int  # character offset where the match begins
    end: int  # character offset one past the end of the match
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @property
    def span(self) -> tuple[int, int]:
        """The ``(start, end)`` character span of the mention."""
        return (self.start, self.end)

    @property
    def link_prob(self) -> float | None:
        """``total_link_count / doc_count`` capped at 1.0.

        Returns ``None`` when either count is unknown and 0.0 when ``doc_count``
        is not positive.
        """
        docs, total = self.doc_count, self.total_link_count
        if docs is None or total is None:
            return None
        if docs <= 0:
            return 0.0
        return min(1.0, total / docs)

    @property
    def prior_prob(self) -> float | None:
        """``link_count / total_link_count`` capped at 1.0.

        Returns ``None`` when either count is unknown and 0.0 when
        ``total_link_count`` is not positive.
        """
        links, total = self.link_count, self.total_link_count
        if links is None or total is None:
            return None
        if total <= 0:
            return 0.0
        return min(1.0, links / total)

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"
|
|
|
|
def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    """Create a blank spaCy pipeline for ``language`` and hand back its tokenizer.

    Args:
        language: spaCy language code passed to ``spacy.blank``.

    Returns:
        The pipeline's tokenizer, usable as ``tokenizer(text)``.
    """
    blank_pipeline = spacy.blank(language)
    return blank_pipeline.tokenizer
|
|
|
|
class DictionaryEntityLinker:
    """Greedy, dictionary-based mention detector and entity linker.

    Surface names are stored in ``name_trie``. For the name whose trie id is ``i``,
    the rows ``data[offsets[i]:offsets[i + 1]]`` describe its candidate entities.
    Each row is either ``(kb_idx,)`` or ``(kb_idx, link_count, total_link_count,
    doc_count)``, and ``kb_idx`` is turned back into a KB id string with
    ``kb_id_trie.restore_key``.
    """

    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        """
        Args:
            name_trie: Trie of mention names (presumably lower-cased when
                ``case_sensitive`` is False — TODO confirm against the builder).
            kb_id_trie: Trie of KB ids, addressed by the ``kb_idx`` stored in ``data``.
            data: Candidate rows; see the class docstring.
            offsets: Start offsets into ``data`` per name id (one extra trailing entry).
            max_mention_length: Maximum number of characters a mention may span.
            case_sensitive: If False, the text is lower-cased before dictionary lookup.
            min_link_prob: Minimum ``Mention.link_prob``; ``None`` disables the check.
            min_prior_prob: Minimum ``Mention.prior_prob``; ``None`` disables the check.
            min_link_count: Minimum ``Mention.link_count``; ``None`` disables the check.
        """
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive

        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count

        # Used only to find token boundaries; always a blank English pipeline.
        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        """Load a linker from ``data_dir`` (the layout written by ``save``).

        Thresholds passed explicitly take precedence; a ``None`` argument falls back
        to the value in ``config.json`` (which may itself be absent, i.e. ``None``).
        """
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))
        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)

        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)
        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)
        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)

        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )

    def detect_mentions(self, text: str) -> list[Mention]:
        """Find dictionary mentions in ``text`` by greedy longest match.

        A mention must start at the start of a spaCy token and end at the end of a
        token. At each start position the longest name that has at least one
        accepted candidate wins, and scanning resumes after it. Returned spans
        therefore never overlap, although several candidates (different ``kb_id``)
        can share one span. ``Mention.text`` is the matched dictionary name, which is
        lower-cased when the linker is case-insensitive.
        """
        tokens = self._tokenizer(text)
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            # Offsets above come from the original text, so the lowered text must
            # keep exactly the same length for slicing to stay aligned.
            text = self._lower_preserving_length(text)

        ret: list[Mention] = []
        cur = 0
        for token in tokens:
            start = token.idx
            if cur > start:
                continue  # still inside the previously matched mention

            window = text[start : start + self._max_mention_length]
            for prefix in sorted(self._name_trie.prefixes(window), key=len, reverse=True):
                end = start + len(prefix)
                if end not in end_offsets:
                    continue  # would end in the middle of a token
                accepted = self._accepted_candidates(prefix, start, end)
                if accepted:
                    ret.extend(accepted)
                    cur = end
                    break
                # Every candidate was filtered out: fall back to a shorter prefix.

        return ret

    def _accepted_candidates(self, name: str, start: int, end: int) -> list[Mention]:
        """Build the mentions for dictionary ``name`` at ``[start, end)`` that pass the thresholds."""
        name_idx = self._name_trie[name]
        data_start, data_end = self._offsets[name_idx : name_idx + 2]

        accepted: list[Mention] = []
        for item in self._data[data_start:data_end]:
            if item.size == 4:
                kb_idx, link_count, total_link_count, doc_count = item
            elif item.size == 1:
                # ``.item()`` works both for a 1-element row and for the numpy scalar
                # produced when ``data`` is one-dimensional.
                kb_idx = item.item()
                link_count, total_link_count, doc_count = None, None, None
            else:
                raise ValueError("Unexpected data array format")

            mention = Mention(
                kb_id=self._kb_id_trie.restore_key(int(kb_idx)),
                text=name,
                start=start,
                end=end,
                link_count=link_count,
                total_link_count=total_link_count,
                doc_count=doc_count,
            )
            # Rows without statistics cannot be filtered and are always kept.
            if item.size == 1 or self._passes_thresholds(mention):
                accepted.append(mention)
        return accepted

    def _passes_thresholds(self, mention: Mention) -> bool:
        """Check a mention with statistics against the configured minimums.

        A threshold of ``None`` disables the corresponding check.
        """
        if self._min_link_prob is not None and mention.link_prob < self._min_link_prob:
            return False
        if self._min_prior_prob is not None and mention.prior_prob < self._min_prior_prob:
            return False
        if self._min_link_count is not None and mention.link_count < self._min_link_count:
            return False
        return True

    @staticmethod
    def _lower_preserving_length(text: str) -> str:
        """Lower-case ``text`` without changing its length.

        ``str.lower`` expands a few characters (e.g. "İ" becomes two code points).
        That would shift every later character offset away from the token offsets
        computed on the original text, so such characters are left unchanged.
        """
        lowered = text.lower()
        if len(lowered) == len(text):
            return lowered
        pieces = []
        for ch in text:
            low = ch.lower()
            pieces.append(low if len(low) == 1 else ch)
        return "".join(pieces)

    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        """Apply ``detect_mentions`` to each text in order."""
        return [self.detect_mentions(text) for text in texts]

    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json"), "w", encoding="utf-8") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )
|
|
|
|
def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    """Load an entity vocabulary from a two-column TSV file.

    Each row is ``<entity_id>\\t<index>``. Blank rows are ignored.

    Args:
        file_path: Path of the TSV file (UTF-8).

    Returns:
        Mapping from entity id to its integer index. A later duplicate id
        overrides an earlier one.

    Raises:
        IndexError: if a non-blank row has fewer than two columns.
        ValueError: if the index column is not an integer.
    """
    vocab: dict[str, int] = {}
    # newline="" is required by the csv module so that quoted fields containing
    # newlines and "\r\n" row terminators are parsed correctly.
    with open(file_path, "r", encoding="utf-8", newline="") as file:
        for row in csv.reader(file, delimiter="\t"):
            if not row:
                continue  # blank line, e.g. a trailing empty line
            vocab[row[0]] = int(row[1])
    return vocab
|
|
|
|
def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    Rows are written as ``<entity_id>\\t<index>`` in the dict's iteration order.
    Missing parent directories are created.

    Args:
        file_path: Path to save the entity vocabulary (may be a bare file name)
        entity_vocab: Entity vocabulary to save
    """
    parent_dir = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # newline="" lets csv.writer control line endings itself (per the csv docs),
    # avoiding "\r\r\n" on platforms that translate "\n".
    with open(file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerows([entity_id, idx] for entity_id, idx in entity_vocab.items())
|
|
|
|
| class _Entity(NamedTuple): |
| entity_id: int |
| start: int |
| end: int |
|
|
| @property |
| def length(self) -> int: |
| return self.end - self.start |
|
|
|
|
def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    """Turn a text (optionally preceded by a title) into token and entity inputs.

    If ``title`` is given, it is tokenized first and followed by
    ``tokenizer.sep_token``, and then the body text follows. Entity positions are
    token indices into this combined sequence; no [CLS]/[SEP] wrapping is added here.

    Args:
        text: Body text.
        mentions: Mentions detected in ``text`` (``None`` means none).
        title: Optional title placed before the body.
        title_mentions: Mentions detected in ``title`` (``None`` means none).
        tokenizer: Tokenizer used for word pieces and id conversion.
        entity_vocab: Mapping from KB id to entity index. Mentions whose KB id is
            missing from it are ignored.

    Returns:
        ``input_ids`` (token ids), ``entity_ids`` (one index per kept mention), and
        ``entity_position_ids`` (the list of token positions for each entity).
    """
    all_tokens: list[str] = []
    entity_ids: list[int] = []
    entity_position_ids: list[list[int]] = []

    def append_segment(segment_text: str, segment_mentions: list[Mention] | None) -> None:
        # Entity positions are relative to the segment; shift them by the number
        # of tokens already emitted before it.
        offset = len(all_tokens)
        segment_tokens, segment_entities = _tokenize_text_with_mentions(
            segment_text,
            [] if segment_mentions is None else segment_mentions,
            tokenizer,
            entity_vocab,
        )
        all_tokens.extend(segment_tokens)
        for ent in segment_entities:
            entity_ids.append(ent.entity_id)
            entity_position_ids.append(list(range(ent.start + offset, ent.end + offset)))

    if title is not None:
        append_segment(title, title_mentions)
        all_tokens.append(tokenizer.sep_token)

    append_segment(text, mentions)

    return {
        "input_ids": tokenizer.convert_tokens_to_ids(all_tokens),
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }
|
|
|
|
def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize ``text`` so that every linkable mention starts and ends on a token boundary.

    The text is cut at every start/end character offset of a mention whose
    ``kb_id`` is in ``entity_vocab``. Each piece is tokenized separately, so no
    word piece can straddle a mention boundary.

    Args:
        text: Input text to tokenize
        mentions: Mentions detected in the text (character offsets)
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - The list of tokens for the whole text
            - One _Entity per linkable mention, holding its token range
    """
    linkable = [m for m in mentions if m.kb_id is not None and m.kb_id in entity_vocab]
    boundaries = sorted({offset for m in linkable for offset in (m.start, m.end)})

    tokens: list[str] = []
    token_index_at: dict[int, int] = {}
    previous = 0
    for boundary in boundaries:
        tokens.extend(tokenizer.tokenize(text[previous:boundary]))
        # Number of tokens emitted before this character offset.
        token_index_at[boundary] = len(tokens)
        previous = boundary
    tokens.extend(tokenizer.tokenize(text[previous:]))

    entities: list[_Entity] = []
    for m in linkable:
        entities.append(_Entity(entity_vocab[m.kb_id], token_index_at[m.start], token_index_at[m.end]))
    return tokens, entities
|
|
|
|
class KPRBertTokenizer(BertTokenizer):
    """BERT tokenizer that also links entity mentions and emits entity model inputs.

    Besides the regular BERT outputs, every encoding contains:

    * ``entity_ids``: entity-vocabulary indices of the linked mentions;
    * ``entity_position_ids``: for each entity, the token positions it covers;
    * ``entity_embeds``: float32 rows of the entity embedding matrix for
      ``entity_ids`` (only when an embeddings file was given).

    A text without any (surviving) linked entity gets a single placeholder entity:
    id 0 at position 0.
    """

    vocab_files_names = {
        **BertTokenizer.vocab_files_names,
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            vocab_file: WordPiece vocabulary, forwarded to ``BertTokenizer``.
            entity_linker_data_file: Path of ``entity_linker/data.npy``. Its parent
                directory is loaded with ``DictionaryEntityLinker.load``.
            entity_vocab_file: TSV entity vocabulary (``<entity_id>\\t<index>``).
            entity_embeddings_file: Optional ``.npy`` matrix, memory-mapped read-only.
                Its row count must equal the entity vocabulary size.
            *args, **kwargs: Forwarded to ``BertTokenizer``.

        Raises:
            ValueError: if the embedding row count does not match the entity vocabulary.
        """
        # Pass vocab_file positionally. Passing it by keyword together with *args
        # makes the first extra positional collide with it ("multiple values").
        super().__init__(vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}

        self.entity_embeddings = None
        if entity_embeddings_file:
            # Memory-map read-only so the full matrix is not loaded into RAM.
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure `embeddings.py` and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        """Encode one text: link mentions, tokenize, add special tokens and align entities.

        Args:
            text: Raw input text.
            **kwargs: Forwarded to ``prepare_for_model`` (except ``return_tensors``).

        Returns:
            The ``prepare_for_model`` outputs plus ``entity_ids``,
            ``entity_position_ids`` and, if embeddings are loaded, ``entity_embeds``.

        Raises:
            ValueError: if special tokens are requested but the sequence does not
                start with [CLS]. Entity positions are shifted assuming it does.
        """
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )

        # The standard BERT logic adds special tokens, truncates and pads. Tensor
        # conversion is deferred to ``__call__`` so everything here stays a list.
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                )
            # Entity positions were computed without special tokens: shift past [CLS].
            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # Drop entities that are empty or reach into the trailing special/padding
        # region. This happens when truncation removed (part of) their tokens.
        max_effective_pos = len(model_inputs["input_ids"]) - self._count_trailing_special_tokens(
            prepared_inputs["input_ids"]
        )
        kept = [
            (entity_id, positions)
            for entity_id, positions in zip(model_inputs["entity_ids"], model_inputs["entity_position_ids"])
            if len(positions) > 0 and max(positions) < max_effective_pos
        ]
        model_inputs["entity_ids"] = [entity_id for entity_id, _ in kept]
        model_inputs["entity_position_ids"] = [positions for _, positions in kept]

        # Always emit at least one entity. This runs after the filter so texts whose
        # entities were all truncated away (or empty texts) also get the placeholder.
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]
            model_inputs["entity_position_ids"] = [[0]]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs

    def _count_trailing_special_tokens(self, input_ids) -> int:
        """Count consecutive [SEP]/[PAD]/[CLS] ids at the end of ``input_ids``."""
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        special_ids = {self.sep_token_id, self.pad_token_id, self.cls_token_id}
        count = 0
        for input_id in reversed(input_ids):
            if int(input_id) not in special_ids:
                break
            count += 1
        return count

    @staticmethod
    def _pad_entity_position_ids(batch_position_ids: list[list[list[int]]], num_entities: int) -> None:
        """Zero-pad entity spans in place.

        Every span is padded to the longest span in the batch, and every sample is
        padded to ``num_entities`` spans.
        """
        max_span_length = max((len(span) for spans in batch_position_ids for span in spans), default=0)
        for spans in batch_position_ids:
            for span in spans:
                span.extend([0] * (max_span_length - len(span)))
            # A fresh list per padding row: ``[[0] * n] * k`` would alias one list k times.
            spans.extend([0] * max_span_length for _ in range(num_entities - len(spans)))

    @classmethod
    def _pad_entity_inputs(cls, batch) -> None:
        """Zero-pad a collated batch's entity inputs in place to a common shape."""
        max_num_entities = max(len(ids) for ids in batch["entity_ids"])
        for entity_ids in batch["entity_ids"]:
            entity_ids.extend([0] * (max_num_entities - len(entity_ids)))

        cls._pad_entity_position_ids(batch["entity_position_ids"], num_entities=max_num_entities)

        if "entity_embeds" in batch:
            embeds = batch["entity_embeds"]
            for i, sample_embeds in enumerate(embeds):
                embeds[i] = np.pad(
                    sample_embeds,
                    pad_width=((0, max_num_entities - len(sample_embeds)), (0, 0)),
                    mode="constant",
                    constant_values=0,
                )

    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        """Encode a single text or a batch of texts.

        Keyword arguments are forwarded to ``prepare_for_model`` for each text, and
        ``return_tensors`` is applied once at the end. With ``padding``, a batch is
        padded by ``self.pad``, and the entity inputs are zero-padded to the largest
        entity count and span length in the batch.

        Raises:
            ValueError: if ``text_pair``, ``text_target`` or ``text_pair_target`` is given.
        """
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs. "
                )

        return_tensors = kwargs.get("return_tensors", None)

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            if return_tensors is not None:
                # Entities covering different numbers of tokens give ragged lists,
                # which cannot be converted to a tensor. Make the spans rectangular.
                self._pad_entity_position_ids(
                    [processed_inputs["entity_position_ids"]],
                    num_entities=len(processed_inputs["entity_ids"]),
                )
            return BatchEncoding(
                processed_inputs,
                tensor_type=return_tensors,
                prepend_batch_axis=True,
            )

        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        if not processed_inputs_list:
            # Empty batch: there are no keys to collate.
            return BatchEncoding({}, tensor_type=return_tensors)

        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )
            # NOTE(review): padded entities use id 0 / position 0 and no entity mask
            # is produced; presumably the model treats id 0 as padding — confirm.
            self._pad_entity_inputs(collated_inputs)
        return BatchEncoding(collated_inputs, tensor_type=return_tensors)

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
        """Save the BERT vocabulary plus the entity linker, entity vocab and embeddings.

        Returns:
            The paths returned by ``BertTokenizer.save_vocabulary`` followed by the
            paths of the entity files, each prefixed with ``save_directory``.
        """
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))
        save_dir = Path(save_directory)

        entity_linker_save_dir = str(save_dir / Path(self.vocab_files_names["entity_linker_data_file"]).parent)
        self.entity_linker.save(entity_linker_save_dir)
        saved_files.extend(
            str(save_dir / file_name)
            for file_name in self.vocab_files_names.values()
            if file_name.startswith("entity_linker/")
        )

        entity_vocab_path = str(save_dir / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(entity_vocab_path)

        if self.entity_embeddings is not None:
            entity_embeddings_path = str(save_dir / self.vocab_files_names["entity_embeddings_file"])
            # NOTE(review): entity_embeddings is a read-only memmap of the source file;
            # saving into the directory it was loaded from overwrites the mapped file —
            # confirm this is safe on all target platforms.
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(entity_embeddings_path)
        return tuple(saved_files)
|
|