# Source code for neurox.data.control_task
from collections import Counter
import numpy as np
def create_sequence_labeling_dataset(
    train_tokens,
    dev_source=None,
    test_source=None,
    case_sensitive=True,
    sample_from="same",
):
    """
    Prepare labels for a control task, as defined in §2.1 of
    `Hewitt and Liang (2019) <https://aclanthology.org/D19-1275.pdf>`_.

    A control task label is drawn at random for each token *type* the first
    time it is encountered, and that label is then reused for every later
    occurrence of the type in the train, dev and test sets. The number of
    control task classes equals the number of distinct labels in
    ``train_tokens['target']``; the classes are the integers
    ``0 .. num_classes - 1``.

    Parameters
    ----------
    train_tokens : dict
        Dictionary containing two lists of lists representing the training
        set, ``source`` (tokens per sentence) and ``target`` (main task
        labels per sentence).
    dev_source : list, optional
        List containing the ``source`` tokens from the development set, as
        produced by ``dev_tokens['source']``.
    test_source : list, optional
        List containing the ``source`` tokens from the test set, as produced
        by ``test_tokens['source']``.
    case_sensitive : bool, optional
        Defaults to True. Sets whether the token comparison (for assigning
        the control task labels) is case-sensitive or case-insensitive.
        Case-insensitive comparison lower-cases tokens for the lookup only;
        the returned ``source`` tokens are never modified.
    sample_from : str, optional
        Defaults to 'same'. The distribution from which control task labels
        are sampled.

        'same': Labels are sampled with the relative frequencies of the main
        task labels, taken in order of first appearance in the training data.

        'uniform': Labels are sampled from a uniform distribution.

        Any value other than 'uniform' is treated like 'same'.

    Returns
    -------
    control_task_tokens : list
        A list with one element per dataset that was passed in: always the
        train set first, then the dev set if ``dev_source`` is given, then
        the test set if ``test_source`` is given.
        Each element is a dictionary containing two lists, ``source`` and
        ``target``. The ``source`` list is the same object that was passed
        in. The ``target`` list holds the control task labels, one list per
        sentence.

    Notes
    -----
    Sampling uses ``np.random.choice`` and therefore NumPy's global random
    state; seed it with ``np.random.seed`` for reproducible labels.
    """
    # Compute label statistics of the main task on the training data.
    # Counter preserves first-seen order, which fixes which control label
    # index receives which main-task frequency.
    label_freqs = Counter(
        label for sentence in train_tokens["target"] for label in sentence
    )
    num_labels = len(label_freqs)
    ct_labels = list(range(num_labels))

    if sample_from == "uniform":
        # Comprehension (rather than list repetition) so that an empty label
        # set yields an empty distribution instead of dividing by zero.
        ct_label_distr = [1 / num_labels for _ in ct_labels]
    else:
        # Hoisted out of the comprehension: the total is loop-invariant.
        total_labels = sum(label_freqs.values())
        ct_label_distr = [count / total_labels for count in label_freqs.values()]

    datasets = [train_tokens["source"]]
    if dev_source is not None:
        datasets.append(dev_source)
    if test_source is not None:
        datasets.append(test_source)

    # Shared across all datasets so that a word type receives the same
    # control label in train, dev and test.
    word_types_to_ct_label = {}
    result = []
    for source_dataset in datasets:
        ct_target = []
        for sent in source_dataset:
            ct_labels_for_sent = []
            for tok in sent:
                word_type = tok if case_sensitive else tok.lower()
                if word_type not in word_types_to_ct_label:
                    # First occurrence of this type: draw its label once.
                    word_types_to_ct_label[word_type] = np.random.choice(
                        ct_labels, p=ct_label_distr
                    )
                ct_labels_for_sent.append(word_types_to_ct_label[word_type])
            ct_target.append(ct_labels_for_sent)

        # Internal sanity checks: one label list per sentence, one label per
        # token.
        assert len(source_dataset) == len(ct_target)
        assert all(len(s) == len(t) for s, t in zip(source_dataset, ct_target))
        result.append({"source": source_dataset, "target": ct_target})
    return result