# Site icon TechAIApp
#
# Preparing Data for BERT Training


"""Process the WikiText dataset for training the BERT model, using the Hugging Face
datasets library.
"""

 

import time

import random

from typing import Iterator

 

import tokenizers

from datasets import load_dataset, Dataset

 

# (path, name) pair of each supported dataset, as passed to load_dataset()
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
# dataset used by this script
PATH, NAME = DATASETS["wikitext-103"]
# tokenizer JSON file loaded with Tokenizer.from_file() when run as a script
TOKENIZER_PATH = "wikitext-103_wordpiece.json"

 

 

def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load the training split of a dataset and tokenize it into documents.

    The "text" column is scanned line by line. A blank line or a line that
    starts with "=" closes the current document; every other line is encoded
    with ``tokenizer`` and appended to the current document.

    Args:
        path: Dataset path passed to ``load_dataset``.
        name: Dataset configuration name passed to ``load_dataset``.
        tokenizer: Tokenizer used to turn each text line into token ids.

    Returns:
        A list of documents. Each document is a non-empty list of token-id
        lists, one per text line.
    """
    dataset = load_dataset(path, name, split="train")
    # Start with one open document so text that appears before the first
    # boundary line is kept instead of raising IndexError.
    docs: list[list[list[int]]] = [[]]
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            # new document encountered; consecutive boundary lines reuse the
            # still-empty document instead of piling up empty lists
            if docs[-1]:
                docs.append([])
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    # the last (or only) document may still be empty
    return [doc for doc in docs if doc]

 

 

def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from all documents.

    Every document is visited ``doc_repeat`` times, in document order within
    each pass. Sample generation is randomized, so repeated passes over the
    same document can produce different samples.

    Args:
        docs: Tokenized documents; each document is a list of token-id lists.
        tokenizer: Tokenizer providing the special-token ids and vocabulary size.
        max_seq_length: Length of every generated token sequence.
        doc_repeat: Number of passes over the whole document list.
        mask_prob: Fraction of tokens selected as prediction targets.
        short_seq_prob: Probability of aiming for a shorter target length.
        max_predictions_per_seq: Upper bound on prediction targets per sample.

    Yields:
        One sample dict per sentence pair, as produced by ``generate_samples``.
    """
    for _ in range(doc_repeat):
        for doc_idx in range(len(docs)):
            # keyword arguments: the parameter order of generate_samples
            # differs from this function's
            yield from generate_samples(
                doc_idx=doc_idx,
                all_docs=docs,
                tokenizer=tokenizer,
                max_seq_length=max_seq_length,
                mask_prob=mask_prob,
                short_seq_prob=short_seq_prob,
                max_predictions_per_seq=max_predictions_per_seq,
            )

 

def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from a given document.

    The document's chunks (token-id lists) are consumed front to back. Each
    step gathers chunks up to a target token length and splits them at a
    random chunk boundary into sentence A and sentence B. When that leaves
    sentence B empty, or otherwise with probability 0.5, sentence B is
    replaced by chunks taken from a random position in a different document,
    and only the chunks used for sentence A are consumed.

    Args:
        doc_idx: Index into ``all_docs`` of the document to sample from.
        all_docs: All tokenized documents; used to draw random sentence Bs.
        tokenizer: Tokenizer forwarded to ``create_sample``.
        max_seq_length: Length of every generated token sequence.
        mask_prob: Fraction of tokens selected as prediction targets.
        short_seq_prob: Probability of aiming for a shorter target length.
        max_predictions_per_seq: Upper bound on prediction targets per sample.

    Yields:
        One sample dict per sentence pair, as returned by ``create_sample``.

    Raises:
        ValueError: If a random sentence B is needed but ``all_docs`` holds
            fewer than two documents.
    """
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # a shorter sequence is used with probability short_seq_prob;
        # the target is drawn once and reused for the whole document
        target_length = random.randint(2, target_length)

    # copy the document so consuming chunks never touches all_docs
    chunks = list(all_docs[doc_idx])

    # exhaust chunks and create samples
    while chunks:
        # scan until target token length; chunks[:end] is the candidate span
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end - 1])
            end += 1
        # randomly separate the span into two segments at a chunk boundary
        sep = random.randint(1, end - 1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
        # sentence B: may be from another document
        if not sentence_b or random.random() < 0.5:
            if len(all_docs) < 2:
                raise ValueError("need at least two documents to draw a random sentence B")
            # find another document (must not be the same as doc_idx): draw
            # from one fewer index and shift past doc_idx
            b_idx = random.randint(0, len(all_docs) - 2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the new document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx]) - 1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            # only sentence A was used; the rest stays for the next sample
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]
        # create a sample from the pair
        yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq)

 

def create_sample(
    sentence_a: list[int],
    sentence_b: list[int],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create a sample from a pair of sentences.

    The pair is laid out as ``[CLS] A [SEP] B [SEP]`` followed by padding up
    to ``max_seq_length``. Some non-special positions are then chosen as
    prediction targets and their tokens possibly replaced.

    Args:
        sentence_a: Token ids of the first segment. Truncated in place if the
            pair is too long.
        sentence_b: Token ids of the second segment. Truncated in place if the
            pair is too long.
        is_random_next: Whether sentence B was drawn from another document.
        tokenizer: Provides the special-token ids, the padding id, and the
            vocabulary size.
        max_seq_length: Length of the returned token sequence.
        mask_prob: Fraction of the sentence tokens selected as targets.
        max_predictions_per_seq: Upper bound on the number of targets.

    Returns:
        A dict with keys "tokens" (token ids after replacement),
        "segment_ids", "is_random_next", "masked_positions" (sorted target
        indices) and "masked_labels" (the original token id at each target).
    """
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]
    # adjust length to fit the max sequence length; 3 slots are reserved for
    # [CLS], [SEP], [SEP]
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    # NOTE(review): padding positions get segment id 1 here — confirm this is
    # what the model expects
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets: only non-special positions are candidates
    special_ids = {cls_id, sep_id, pad_id}
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in special_ids]
    random.shuffle(cand_indices)
    num_predictions = round((len(sentence_a) + len(sentence_b)) * mask_prob)
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    # randomly mask some tokens
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            # ids below 4 are skipped — presumably reserved for special
            # tokens; verify against the tokenizer's vocabulary
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
    return {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }

 

 

def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences in place to at most a maximum total length.

    Tokens are removed one at a time from whichever sequence is currently
    longer (``sentence_b`` on a tie), from its front or its back with equal
    probability, until the combined length fits.

    Args:
        sentence_a: First token-id list; modified in place.
        sentence_b: Second token-id list; modified in place.
        max_num_tokens: Largest allowed value of
            ``len(sentence_a) + len(sentence_b)``.

    Raises:
        ValueError: If ``max_num_tokens`` is negative.
    """
    if max_num_tokens < 0:
        # fail before touching either list; otherwise both are emptied and
        # pop() then raises on an empty list
        raise ValueError("max_num_tokens must be non-negative")
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()

 

 

if __name__ == "__main__":
    # Script entry point: tokenize the corpus, generate the samples, save
    # them to a parquet file, then read the file back and print a few samples.
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")
    # Save dataset to parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")
    # Load dataset from parquet file
    # NOTE(review): streaming=True is passed through from_parquet's keyword
    # arguments — confirm the installed datasets version supports it
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")
    # Print a few samples (the first four)
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")



# Source link
#
# Exit mobile version