-
Notifications
You must be signed in to change notification settings - Fork 91
Expand file tree
/
Copy pathintervals_sampler.py
More file actions
430 lines (384 loc) · 17.3 KB
/
intervals_sampler.py
File metadata and controls
430 lines (384 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
"""
This module provides the `IntervalsSampler` class and supporting
methods.
"""
import logging
import random
from collections import namedtuple

import numpy as np

from selene_sdk.samplers.samples_batch import SamplesBatch

from ..utils import get_indices_and_probabilities
from .online_sampler import OnlineSampler

# Module-level logger, named after this module so log records can be
# filtered per the standard `logging` convention.
logger = logging.getLogger(__name__)
# Lightweight immutable record pairing a set of sample indices with the
# probability weight assigned to each when drawing randomly among them.
SampleIndices = namedtuple("SampleIndices", ["indices", "weights"])
"""
A tuple containing the indices for some samples, and a weight to
allot to each index when randomly drawing from them.

Parameters
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

Attributes
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

"""
# @TODO: Extend this class to work with stranded data.
# @TODO: Extend this class to work with stranded data.
class IntervalsSampler(OnlineSampler):
    """
    Draws samples from pre-specified windows in the reference sequence.

    Parameters
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        A reference sequence from which to create examples.
    target_path : str
        Path to tabix-indexed, compressed BED file (`*.bed.gz`) of genomic
        coordinates mapped to the genomic features we want to predict.
    features : list(str)
        List of distinct features that we aim to predict.
    intervals_path : str
        The path to the file that contains the intervals to sample from.
        In this file, each interval should occur on a separate line.
    sample_negative : bool, optional
        Default is `False`. This tells the sampler whether negative
        examples (i.e. with no positive labels) should be drawn when
        generating samples. If `True`, both negative and positive
        samples will be drawn. If `False`, only samples with at least
        one positive label will be drawn.
    seed : int, optional
        Default is 436. Sets the random seed for sampling.
    validation_holdout : list(str) or float, optional
        Default is `['chr6', 'chr7']`. Holdout can be regional or
        proportional. If regional, expects a list (e.g. `['X', 'Y']`).
        Regions must match those specified in the first column of the
        tabix-indexed BED file. If proportional, specify a percentage
        between (0.0, 1.0). Typically 0.10 or 0.20.
    test_holdout : list(str) or float, optional
        Default is `['chr8', 'chr9']`. See documentation for
        `validation_holdout` for additional information.
    sequence_length : int, optional
        Default is 1000. Model is trained on sequences of `sequence_length`
        where genomic features are annotated to the center regions of
        these sequences.
    center_bin_to_predict : int, optional
        Default is 200. Query the tabix-indexed file for a region of
        length `center_bin_to_predict`.
    feature_thresholds : float [0.0, 1.0] or None, optional
        Default is 0.5. The `feature_threshold` to pass to the
        `GenomicFeatures` object.
    mode : {'train', 'validate', 'test'}
        Default is `'train'`. The mode to run the sampler in.
    save_datasets : list of str
        Default is `["test"]`. The list of modes for which we should
        save the sampled data to file.
    output_dir : str or None, optional
        Default is None. The path to the directory where we should
        save sampled examples for a mode. If `save_datasets` is
        a non-empty list, `output_dir` must be specified. If
        the path in `output_dir` does not exist it will be created
        automatically.

    Attributes
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        The reference sequence that examples are created from.
    target : selene_sdk.targets.Target
        The `selene_sdk.targets.Target` object holding the features that we
        would like to predict.
    sample_from_intervals : list(tuple(str, int, int))
        A list of coordinates that specify the intervals we can draw
        samples from.
    interval_lengths : list(int)
        A list of the lengths of the intervals that we can draw samples
        from. The probability that we will draw a sample from an
        interval is a function of that interval's length and the length
        of all other intervals.
    sample_negative : bool
        Whether negative examples (i.e. with no positive label) should
        be drawn when generating samples. If `True`, both negative and
        positive samples will be drawn. If `False`, only samples with at
        least one positive label will be drawn.
    validation_holdout : list(str) or float
        The samples to hold out for validating model performance. These
        can be "regional" or "proportional". If regional, this is a list
        of region names (e.g. `['chrX', 'chrY']`). These Regions must
        match those specified in the first column of the tabix-indexed
        BED file. If proportional, this is the fraction of total samples
        that will be held out.
    test_holdout : list(str) or float
        The samples to hold out for testing model performance. See the
        documentation for `validation_holdout` for more details.
    sequence_length : int
        The length of the sequences to train the model on.
    bin_radius : int
        From the center of the sequence, the radius in which to detect
        a feature annotation in order to include it as a sample's label.
    surrounding_sequence_radius : int
        The length of sequence falling outside of the feature detection
        bin (i.e. `bin_radius`) center, but still within the
        `sequence_length`.
    modes : list(str)
        The list of modes that the sampler can be run in.
    mode : str
        The current mode that the sampler is running in. Must be one of
        the modes listed in `modes`.
    """

    def __init__(self,
                 reference_sequence,
                 target_path,
                 features,
                 intervals_path,
                 sample_negative=False,
                 seed=436,
                 validation_holdout=['chr6', 'chr7'],
                 test_holdout=['chr8', 'chr9'],
                 sequence_length=1000,
                 center_bin_to_predict=200,
                 feature_thresholds=0.5,
                 mode="train",
                 save_datasets=["test"],
                 output_dir=None):
        """
        Constructs a new `IntervalsSampler` object.
        """
        super(IntervalsSampler, self).__init__(
            reference_sequence,
            target_path,
            features,
            seed=seed,
            validation_holdout=validation_holdout,
            test_holdout=test_holdout,
            sequence_length=sequence_length,
            center_bin_to_predict=center_bin_to_predict,
            feature_thresholds=feature_thresholds,
            mode=mode,
            save_datasets=save_datasets,
            output_dir=output_dir)

        # Per-mode state: `_sample_from_mode[mode]` holds the SampleIndices
        # (interval indices + draw weights) that mode may draw from, and
        # `_randcache[mode]` holds a batch of pre-drawn random interval
        # indices plus a cursor (`sample_next`) into that batch, so we do
        # not call `np.random.choice` once per sample.
        self._sample_from_mode = {}
        self._randcache = {}
        for mode in self.modes:
            self._sample_from_mode[mode] = None
            self._randcache[mode] = {"cache_indices": None, "sample_next": 0}

        self.sample_from_intervals = []
        self.interval_lengths = []

        # `_holdout_type` is set by the OnlineSampler base class (not shown
        # here); presumably "chromosome" when the holdouts are lists of
        # region names, and something else when they are proportions —
        # confirm against OnlineSampler.
        if self._holdout_type == "chromosome":
            self._partition_dataset_chromosome(intervals_path)
        else:
            self._partition_dataset_proportion(intervals_path)

        # Pre-fill the random-index cache for every mode.
        for mode in self.modes:
            self._update_randcache(mode=mode)

        self.sample_negative = sample_negative

    def _partition_dataset_proportion(self, intervals_path):
        """
        When holdout sets are created by randomly sampling a proportion
        of the data, this method is used to divide the data into
        train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a separate
            line.
        """
        # Load every interval (BED-like: chrom, start, end in the first
        # three tab-separated columns).
        with open(intervals_path, 'r') as file_handle:
            for line in file_handle:
                cols = line.strip().split('\t')
                chrom = cols[0]
                start = int(cols[1])
                end = int(cols[2])
                self.sample_from_intervals.append((chrom, start, end))
                self.interval_lengths.append(end - start)
        n_intervals = len(self.sample_from_intervals)

        # all indices in the intervals list are shuffled
        select_indices = list(range(n_intervals))
        np.random.shuffle(select_indices)

        # the first section of indices is used as the validation set
        # (`validation_holdout` is a fraction in this code path).
        n_indices_validate = int(n_intervals * self.validation_holdout)
        val_indices, val_weights = get_indices_and_probabilities(
            self.interval_lengths, select_indices[:n_indices_validate])
        self._sample_from_mode["validate"] = SampleIndices(
            val_indices, val_weights)

        if self.test_holdout:
            # if applicable, the second section of indices is used as the
            # test set
            n_indices_test = int(n_intervals * self.test_holdout)
            test_indices_end = n_indices_test + n_indices_validate
            test_indices, test_weights = get_indices_and_probabilities(
                self.interval_lengths,
                select_indices[n_indices_validate:test_indices_end])
            self._sample_from_mode["test"] = SampleIndices(
                test_indices, test_weights)

            # remaining indices are for the training set
            tr_indices, tr_weights = get_indices_and_probabilities(
                self.interval_lengths, select_indices[test_indices_end:])
            self._sample_from_mode["train"] = SampleIndices(
                tr_indices, tr_weights)
        else:
            # remaining indices are for the training set
            tr_indices, tr_weights = get_indices_and_probabilities(
                self.interval_lengths, select_indices[n_indices_validate:])
            self._sample_from_mode["train"] = SampleIndices(
                tr_indices, tr_weights)

    def _partition_dataset_chromosome(self, intervals_path):
        """
        When holdout sets are created by selecting all samples from a
        specified region (e.g. a chromosome) this method is used to
        divide the data into train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a separate
            line.
        """
        # Start each mode with empty (mutable) index/weight lists; indices
        # are appended in place below, then replaced with weighted versions.
        for mode in self.modes:
            self._sample_from_mode[mode] = SampleIndices([], [])
        with open(intervals_path, 'r') as file_handle:
            for index, line in enumerate(file_handle):
                cols = line.strip().split('\t')
                chrom = cols[0]
                start = int(cols[1])
                end = int(cols[2])
                # Route each interval to validate/test/train based on its
                # chromosome; `test_holdout` may be falsy (no test set).
                if chrom in self.validation_holdout:
                    self._sample_from_mode["validate"].indices.append(
                        index)
                elif self.test_holdout and chrom in self.test_holdout:
                    self._sample_from_mode["test"].indices.append(
                        index)
                else:
                    self._sample_from_mode["train"].indices.append(
                        index)
                self.sample_from_intervals.append((chrom, start, end))
                self.interval_lengths.append(end - start)

        # Convert each mode's raw index list into indices + length-based
        # draw probabilities.
        for mode in self.modes:
            sample_indices = self._sample_from_mode[mode].indices
            indices, weights = get_indices_and_probabilities(
                self.interval_lengths, sample_indices)
            self._sample_from_mode[mode] = \
                self._sample_from_mode[mode]._replace(
                    indices=indices, weights=weights)

    def _retrieve(self, chrom, position):
        """
        Retrieves samples around a position in the `reference_sequence`.

        Parameters
        ----------
        chrom : str
            The name of the region (e.g. "chrX", "YFP")
        position : int
            The position in the query region that we will search around
            for samples.

        Returns
        -------
        retrieved_seq, retrieved_targets : \
        tuple(numpy.ndarray, list(numpy.ndarray))
            A tuple containing the numeric representation of the
            sequence centered at the query position, as well as a list
            of samples within this region that met the filtering
            standards.
        """
        # `_start_radius`/`_end_radius` come from the base class —
        # presumably derived from `center_bin_to_predict`; confirm against
        # OnlineSampler.
        bin_start = position - self._start_radius
        bin_end = position + self._end_radius
        retrieved_targets = self.target.get_feature_data(
            chrom, bin_start, bin_end)
        # Reject feature-less positions unless negative sampling is on;
        # `None` signals the caller to draw a different position.
        if not self.sample_negative and np.sum(retrieved_targets) == 0:
            logger.info("No features found in region surrounding "
                        "region \"{0}\" position {1}. Sampling again.".format(
                            chrom, position))
            return None

        # Widen the feature bin out to the full model input window.
        window_start = bin_start - self.surrounding_sequence_radius
        window_end = bin_end + self.surrounding_sequence_radius
        # Strand is chosen uniformly at random from the two sides.
        strand = self.STRAND_SIDES[random.randint(0, 1)]
        retrieved_seq = \
            self.reference_sequence.get_encoding_from_coords(
                chrom, window_start, window_end, strand)
        if retrieved_seq.shape[0] == 0:
            logger.info("Full sequence centered at region \"{0}\" position "
                        "{1} could not be retrieved. Sampling again.".format(
                            chrom, position))
            return None
        elif np.sum(retrieved_seq) / float(retrieved_seq.shape[0]) < 0.60:
            # NOTE(review): this assumes ambiguous bases contribute 0 to the
            # one-hot sum, and the 0.60 threshold rejects >40% ambiguity
            # while the message says "Over 30%" — confirm the intended
            # threshold and the encoding of 'N' in `reference_sequence`.
            logger.info("Over 30% of the bases in the sequence centered "
                        "at region \"{0}\" position {1} are ambiguous ('N'). "
                        "Sampling again.".format(chrom, position))
            return None

        # Optionally record the sampled coordinates (and which features
        # were positive) for later export; flush to file periodically so
        # the in-memory list stays bounded.
        if self.mode in self._save_datasets:
            feature_indices = ';'.join(
                [str(f) for f in np.nonzero(retrieved_targets)[0]])
            self._save_datasets[self.mode].append(
                [chrom,
                 window_start,
                 window_end,
                 strand,
                 feature_indices])
            if len(self._save_datasets[self.mode]) > 200000:
                self.save_dataset_to_file(self.mode)
        return (retrieved_seq, retrieved_targets)

    def _update_randcache(self, mode=None):
        """
        Updates the cache of indices of intervals. This allows us
        to randomly sample from our data without having to use a
        fixed-point approach or keeping all labels in memory.

        Parameters
        ----------
        mode : str or None, optional
            Default is `None`. The mode that these samples should be
            used for. See `selene_sdk.samplers.IntervalsSampler.modes` for
            more information.
        """
        if not mode:
            mode = self.mode
        # Draw (with replacement) a full batch of interval indices at once,
        # weighted by each interval's share of total length, and reset the
        # consumption cursor.
        self._randcache[mode]["cache_indices"] = np.random.choice(
            self._sample_from_mode[mode].indices,
            size=len(self._sample_from_mode[mode].indices),
            replace=True,
            p=self._sample_from_mode[mode].weights)
        self._randcache[mode]["sample_next"] = 0

    def sample(self, batch_size=1):
        """
        Randomly draws a mini-batch of examples and their corresponding
        labels.

        Parameters
        ----------
        batch_size : int, optional
            Default is 1. The number of examples to include in the
            mini-batch.

        Returns
        -------
        SamplesBatch
            A batch containing the numeric representation of the
            sequence examples and their corresponding labels. The
            shape of `sequences` will be
            :math:`B \\times L \\times N`, where :math:`B` is
            `batch_size`, :math:`L` is the sequence length, and
            :math:`N` is the size of the sequence type's alphabet.
            The shape of `targets` will be :math:`B \\times F`,
            where :math:`F` is the number of features.
        """
        # NOTE(review): the alphabet size 4 is hardcoded here (DNA one-hot);
        # confirm this matches `reference_sequence`'s alphabet.
        sequences = np.zeros((batch_size, self.sequence_length, 4))
        targets = np.zeros((batch_size, self.n_features))
        n_samples_drawn = 0
        # Keep drawing until the batch is full; `_retrieve` returning None
        # (filtered-out position) simply triggers another draw.
        while n_samples_drawn < batch_size:
            sample_index = self._randcache[self.mode]["sample_next"]
            # Refresh the cache once every pre-drawn index has been used.
            if sample_index == len(self._sample_from_mode[self.mode].indices):
                self._update_randcache()
                sample_index = 0

            rand_interval_index = \
                self._randcache[self.mode]["cache_indices"][sample_index]
            self._randcache[self.mode]["sample_next"] += 1

            interval_info = self.sample_from_intervals[rand_interval_index]
            interval_length = self.interval_lengths[rand_interval_index]

            # Pick a position uniformly at random within the interval.
            chrom = interval_info[0]
            position = int(
                interval_info[1] + random.uniform(0, 1) * interval_length)

            retrieve_output = self._retrieve(chrom, position)
            if not retrieve_output:
                continue
            seq, seq_targets = retrieve_output
            sequences[n_samples_drawn, :, :] = seq
            targets[n_samples_drawn, :] = seq_targets
            n_samples_drawn += 1
        return SamplesBatch(sequences, target_batch=targets)