-
-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy pathlanguage_modeling.py
More file actions
38 lines (26 loc) · 1.08 KB
/
language_modeling.py
File metadata and controls
38 lines (26 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pathlib
import torch
class TextReader:
    """Read one split of a text dataset and encode it into a flat tensor of token IDs.

    Args:
        encoder: Callable applied to each raw line of text; its result is
            indexed with ``"token_ids"`` to obtain that line's token IDs.
            (Presumably a tokenizer wrapper -- confirm against callers.)
        mode: Name of the dataset split (e.g. ``"train"``). ``read`` looks up
            the attribute ``f"{mode}_path"`` on the dataset metadata object.
    """

    def __init__(self, encoder, mode):
        self.encoder = encoder
        self.mode = mode

    def read(self, dataset_meta):
        """Read the file for ``self.mode`` and encode all of its text.

        Each line is passed to the encoder unstripped (trailing newline
        included), and the token IDs of all lines are concatenated in file
        order into a single sequence.

        The result is stored on ``self.encoded_text`` (kept for existing
        callers that read the attribute) and is also returned.

        Args:
            dataset_meta: Object exposing a ``<mode>_path`` attribute that holds
                the path (str or path-like) of a UTF-8 text file.

        Returns:
            torch.LongTensor: all token IDs of the file, concatenated (a flat
            tensor when the encoder yields plain integer IDs).

        Raises:
            AttributeError: if ``dataset_meta`` has no ``<mode>_path`` attribute.
            OSError: if the file cannot be opened.
        """
        data_path = pathlib.Path(getattr(dataset_meta, f"{self.mode}_path"))

        encoded_text = []
        with data_path.open(encoding="utf-8") as f:
            # Iterate the file lazily rather than via f.readlines(), so a list
            # of every raw line is never materialized next to the ID list.
            for line in f:
                encoded_text.extend(self.encoder(line)["token_ids"])

        # The whole dataset is treated as one example: a single sequence of IDs.
        self.encoded_text = torch.LongTensor(encoded_text)
        return self.encoded_text