Experimental data sets for drug discovery are sometimes limited in size, because this type of data is difficult to gather. Drug discovery data sets are expensive to obtain, and some are the result of clinical trials, which might not be repeatable for ethical reasons. The ClinTox data set, for example, comprises data from FDA clinical trials of drug candidates, some of which failed because of toxic side effects [2]. When training data is scarce, one-shot learning methods have demonstrated significantly better performance than methods consisting only of graphical convolution networks. The performance of one-shot network architectures is discussed here for several drug discovery data sets, which are described in Table 1.
These data sets, along with one-shot learning methods, have been integrated into the DeepChem deep learning framework as a result of research published by Altae-Tran, et al. [1]. While data remains scarce in some problem domains, such as drug discovery, one-shot learning methods offer an important alternative network architecture, one that may substantially outperform methods which use only graphical convolution.
Dataset | Category | Description | Network Type | Number of Tasks | Compounds |
---|---|---|---|---|---|
Tox21 | Physiology | toxicity | Classification | 12 | 8,014 |
SIDER | Physiology | side reactions | Classification | 27 | 1,427 |
MUV | Biophysics | bioactivity | Classification | 17 | 93,127 |
Table 1. DeepChem drug discovery data sets investigated with one-shot learning.
One-Shot Network Architectures Produce Most Accurate Results When Applied to Small Population Training Sets
The original motivation for investigating one-shot neural networks arose from the observation that humans can learn sufficient representations from small amounts of data, and can later apply a learned representation to correctly distinguish between objects which they have observed only once. The one-shot network architecture was originally developed, and applied to image data, with this motivation in mind [3, 5].
The question was whether an artificial neural network, given a small data set, could similarly learn a sufficient number of features through training and perform at a satisfactory level. After some period of development, one-shot learning methods have demonstrated good success [3, 4, 5, 6].
The description provided here of the one-shot approach focuses mostly on methodology, and less on the theoretical and experimental results which support the method. The simplest one-shot method computes a distance-weighted combination of support set labels. The distance metric can be defined using a structure called a Siamese network, in which two identical networks are used. The first twin produces a vector output for the molecule being queried, while the other twin produces a vector representing an element of the support set. Any difference between the outputs can be interpreted as a dissimilarity measure between the query structure and that particular structure in the support set. A separate dissimilarity measure is computed for each element in the support set, and these measures are converted into normalized weights, which are used to combine the support set labels into a prediction for the query structure. For example, if the query structure is much more similar to two structures, out of, say, twenty, in the support set, then the weighted prediction will be nearly the average of the labels of the two support structures which most resemble the queried structure.
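The distance-weighted prediction can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not DeepChem's implementation: it assumes the twin networks have already produced the embeddings, it uses cosine similarity followed by a softmax as the weighting (the attention kernel used in matching networks [5]), and the helper names `cosine_similarity` and `one_shot_predict` are hypothetical.

```python
import numpy as np


def cosine_similarity(query, support):
    """Cosine similarity between one query vector and each row of `support`."""
    q = query / np.linalg.norm(query)
    s = support / np.linalg.norm(support, axis=1, keepdims=True)
    return s @ q


def one_shot_predict(query_emb, support_embs, support_labels):
    """Hypothetical helper: similarity-weighted combination of support labels.

    query_emb:      (d,) embedding of the query molecule (first twin).
    support_embs:   (n, d) embeddings of the support molecules (second twin).
    support_labels: (n,) binary assay labels, 1 = positive, 0 = negative.
    Returns a score in [0, 1]; higher means "more likely positive".
    """
    sims = cosine_similarity(query_emb, support_embs)
    # Softmax turns similarities into normalized weights that sum to 1
    # (subtracting the max only improves numerical stability).
    weights = np.exp(sims - sims.max())
    weights /= weights.sum()
    # Prediction = weighted average of the support labels.
    return float(weights @ support_labels)


support = np.eye(4)                      # four toy support embeddings
labels = np.array([1, 1, 0, 0])          # two positives, two negatives
query = np.array([1.0, 1.0, 0.0, 0.0])   # resembles the two positives
print(one_shot_predict(query, support, labels))  # above 0.5: leans positive
```

Because cosine similarities lie in [−1, 1], the softmax weights in this toy example are fairly soft; trained embeddings that separate similar and dissimilar molecules more strongly would push the prediction toward the "average of the nearest support labels" behaviour described above.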
There are two other one-shot methods which take more complex considerations into account. In the Siamese network one-shot approach, the vector embeddings of the query structure and of each individual support structure are computed independently of the rest of the support set. However, it has been shown empirically that taking into account the context of all support set elements, when computing the vector embeddings of the query and of each individual support structure, yields better one-shot network performance. This approach is called full context embedding, since the full context of the support set is taken into account when computing every vector embedding. In the full context embedding approach, the embeddings of all support structures are allowed to influence the embedding of the query structure.
The full context embedding approach uses Siamese, i.e. matching, networks as before, but once the embeddings are computed, they are further processed by Long Short-Term Memory (LSTM) structures. The embeddings before LSTM processing will be referred to here as pre-contextualized vectors. The full contextual embeddings for the support structures are produced using an LSTM structure called a bidirectional LSTM (biLSTM), while the full contextual embedding for the query structure is produced by an LSTM structure called an attentional LSTM (attLSTM). An LSTM is a type of recurrent neural network, which can process sequences of input. With the biLSTM, the support set is viewed as a sequence of vectors. A bidirectional LSTM is used, instead of a single-direction LSTM, in order to reduce dependence on the sequence order. This improves model performance because the support set has no natural order. However, the biLSTM does not remove all dependence on sequence order.
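To make the remaining order dependence concrete, here is a minimal NumPy sketch of a bidirectional LSTM applied to a support set, using the full context embedding form g(x_i) = g'(x_i) + h_fwd,i + h_bwd,i described for matching networks [5]. The weights are random and untrained, the class and function names are hypothetical, and this is not DeepChem's implementation; its only purpose is to show that listing the same support set in a different order changes the embeddings by more than a simple reordering.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class LSTMCell:
    """Minimal LSTM cell with random, untrained weights (illustration only)."""

    def __init__(self, dim, rng):
        self.W = rng.normal(scale=0.3, size=(4 * dim, dim))  # input weights
        self.U = rng.normal(scale=0.3, size=(4 * dim, dim))  # recurrent weights
        self.b = np.zeros(4 * dim)

    def step(self, x, h, c):
        z = self.W @ x + self.U @ h + self.b
        i, f, o, g = np.split(z, 4)  # input, forget, output gates and candidate
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        return h, c


def run(cell, seq):
    """Run a cell over a sequence of vectors; return the hidden state at each step."""
    h, c = np.zeros(seq.shape[1]), np.zeros(seq.shape[1])
    states = []
    for x in seq:
        h, c = cell.step(x, h, c)
        states.append(h)
    return np.array(states)


def bilstm_support_embedding(support, fwd, bwd):
    """Full context support embeddings: g(x_i) = g'(x_i) + forward h_i + backward h_i."""
    h_fwd = run(fwd, support)
    h_bwd = run(bwd, support[::-1])[::-1]  # backward pass, re-aligned to input order
    return support + h_fwd + h_bwd


rng = np.random.default_rng(0)
dim = 8
fwd, bwd = LSTMCell(dim, rng), LSTMCell(dim, rng)
support = rng.normal(size=(5, dim))   # five pre-contextualized support vectors
perm = np.array([2, 0, 4, 1, 3])      # the same set, listed in another order
emb = bilstm_support_embedding(support, fwd, bwd)
emb_shuffled = bilstm_support_embedding(support[perm], fwd, bwd)
# An order-independent embedding would give emb_shuffled == emb[perm].
print(np.allclose(emb_shuffled, emb[perm]))  # False: some order dependence remains
```

Reading the sequence in both directions lets every support vector see all of the others, but each hidden state still depends on which vectors came before it, which is the residual order dependence discussed in the text.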
The attLSTM constructs an order-independent full contextual embedded vector of the query structure. The full details of the attLSTM will not be discussed here, beyond saying that both the biLSTM and the attLSTM are network elements which interpret a set of structures as a sequence of pre-contextualized vectors and convert that sequence into full context embedded vectors. The biLSTM produces one full context embedded vector for each structure in the support set, and the attLSTM produces a single full context embedded vector for the query structure.
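A minimal sketch of the attentional read loop, written in the form given for matching networks [5]: the LSTM input is the pre-contextualized query vector f'(x), its recurrent state is built from the previous output and an attention read-out over the support set, and a skip connection adds f'(x) back after each step. It assumes NumPy, random untrained weights, zero initial states and three read steps; the function name and those choices are hypothetical rather than taken from DeepChem. Because the read-out is a softmax-weighted sum over the support set, permuting the support vectors leaves the query embedding unchanged.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def attlstm_query_embedding(query, support, steps=3, seed=0):
    """Hypothetical attentional LSTM (matching-network style).

    query:   (d,) pre-contextualized query vector f'(x).
    support: (n, d) support embeddings g(S).
    Returns a (d,) full context query embedding.
    """
    d = query.shape[0]
    rng = np.random.default_rng(seed)               # fixed random, untrained weights
    W = rng.normal(scale=0.3, size=(4 * d, d))      # weights on the input f'(x)
    U = rng.normal(scale=0.3, size=(4 * d, 2 * d))  # weights on the state [h, r]
    h, r, c = np.zeros(d), np.zeros(d), np.zeros(d)
    for _ in range(steps):
        z = W @ query + U @ np.concatenate([h, r])
        i, f, o, g = np.split(z, 4)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c) + query  # skip connection back to f'(x)
        # Attention read-out: a weighted sum over the support set, so it does
        # not depend on the order in which the support vectors are listed.
        r = softmax(support @ h) @ support
    return h


rng = np.random.default_rng(0)
query = rng.normal(size=8)
support = rng.normal(size=(5, 8))
a = attlstm_query_embedding(query, support)
b = attlstm_query_embedding(query, support[::-1])  # same set, reversed order
print(np.allclose(a, b))  # True: the query embedding ignores support order
```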
A further improvement has been made to the one-shot model described here. As mentioned, the biLSTM does not produce an entirely order-independent full context embedding for each pre-contextualized vector corresponding to a support structure. Because the support set has no natural order, any sequence order dependence present in the full context embedded support vectors is an unwanted artifact, and will lead to reduced model performance. There is another problem: as they have been defined, the full context embedded vectors of the support structures depend only on the pre-contextualized support vectors, and not on the pre-contextualized query vector. On the other hand, the full context embedded vector of the query structure depends on both its own pre-contextualized vector and the pre-contextualized vectors of the support set. This asymmetry suggests that some available information is not used by the model, and that performance could be improved by removing both the asymmetry and the order dependence of the full context embedded support vectors.
To address these problems, Altae-Tran, et al. developed a new LSTM model called the Iterative Refinement LSTM (IterRefLSTM). The full details of how the IterRefLSTM model operates are beyond the scope of this discussion; a full explanation can be found in Altae-Tran, et al. [1]. Put briefly, the full contextual embedded vectors of the support and query structures are co-evolved in an iterative process which uses an attLSTM element. This removes the order dependence in the full contextual embeddings of the support, as well as the asymmetry in dependency between the full context embedded vectors of the support and query structures.
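The co-evolution idea can be illustrated with a deliberately simplified stand-in. This sketch does not reproduce the IterRefLSTM equations in [1]: the LSTM cells are replaced by a hypothetical residual tanh update, the weights are random, and letting each support vector attend over the support set plus the query is our own assumption, made so that the support embeddings depend on the query. What it does show is the two properties described above: after refinement, the query embedding is independent of support order, the support embeddings simply follow any reordering of the inputs, and the support embeddings now depend on the query.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def co_refine(query, support, steps=3, seed=0):
    """Hypothetical co-refinement loop (simplified; not the equations of [1]).

    query:   (d,) pre-contextualized query vector.
    support: (n, d) pre-contextualized support vectors.
    Returns refined (query, support) embeddings.
    """
    _, d = support.shape
    rng = np.random.default_rng(seed)
    Wq = rng.normal(scale=0.3, size=(d, 2 * d))  # stand-in for the query LSTM
    Ws = rng.normal(scale=0.3, size=(d, 2 * d))  # stand-in for the support LSTM
    q, S = query.copy(), support.copy()
    for _ in range(steps):
        # The query reads from the current support embeddings.
        r_q = softmax(S @ q) @ S
        # Each support vector reads from the support set *and* the query,
        # so the support embeddings also depend on the query.
        pool = np.vstack([S, q])
        r_S = softmax(S @ pool.T, axis=1) @ pool
        # Simultaneous residual updates of both sides.
        q_new = query + np.tanh(Wq @ np.concatenate([q, r_q]))
        S_new = support + np.tanh(np.hstack([S, r_S]) @ Ws.T)
        q, S = q_new, S_new
    return q, S


rng = np.random.default_rng(0)
query = rng.normal(size=8)
support = rng.normal(size=(5, 8))
perm = np.array([3, 0, 4, 2, 1])
q_a, S_a = co_refine(query, support)
q_b, S_b = co_refine(query, support[perm])
print(np.allclose(q_a, q_b), np.allclose(S_b, S_a[perm]))  # True True
```

Both sides are updated from the previous iteration's values before either is overwritten, which is what "co-evolved" means in this sketch.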
A brief summary of the one-shot network architectures discussed is presented in Table 2.
Architecture | Description |
---|---|
Siamese Networks | score comparison, dissimilarity measure |
Attention LSTM (attLSTM) | full context embeddings give better extraction of information from prior (support) data; biLSTM support embeddings retain some order dependence |
Iterative Refinement LSTMs (IterRefLSTM) | similar to attLSTM, but removes the order dependence by evolving the query and support embeddings simultaneously in an iterative loop |
Table 2. One-shot networks used for investigating low-population biological assay data sets.
Computed One-Shot Performance Metrics Compared to Published Values
A comparison of independently computed values is made here with published values from Altae-Tran, et al. [1]. Quantitative results for classification tasks associated with the Tox21, SIDER, and MUV data sets were obtained by evaluating the area under the receiver operating characteristic curve (AUROC) [12]. For data sets having more than one task, the median of the performance metric over all tasks in the held-out data sets is reported. A k-fold cross-validation† was then done, with k = 4. The mean of performances across all cross-validations was then taken, and reported as the performance measure for the data set. A discussion of the standard deviation is given further below.
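The order of aggregation described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions: folds are formed over tasks, AUROC is computed directly from its rank definition rather than through a library call, the scores are synthetic and simply given rather than produced by trained models, and the function names are hypothetical. It shows only the aggregation: the median over each fold's held-out tasks, then the mean over the k = 4 folds.

```python
import numpy as np


def auroc(y_true, scores):
    """AUROC from its rank definition: the probability that a randomly chosen
    positive is scored above a randomly chosen negative (ties count 1/2)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    diff = scores[y_true][:, None] - scores[~y_true][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def cross_validated_score(task_results, k=4, seed=0):
    """Hypothetical aggregation over tasks.

    task_results: one (y_true, scores) pair per task. In a real run, the scores
    for a fold's held-out tasks would come from a model trained on the other
    k - 1 folds; here they are simply given. Assumes at least k tasks.
    Returns (mean of fold medians, list of fold medians).
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(task_results)), k)
    fold_medians = [
        float(np.median([auroc(*task_results[t]) for t in fold])) for fold in folds
    ]
    return float(np.mean(fold_medians)), fold_medians


rng = np.random.default_rng(0)
tasks = []
for _ in range(12):                              # 12 synthetic tasks (Tox21 has 12)
    y = rng.integers(0, 2, size=50)              # synthetic binary labels
    tasks.append((y, y + rng.normal(size=50)))   # noisy synthetic scores
score, per_fold = cross_validated_score(tasks, k=4)
print(round(score, 3), [round(m, 3) for m in per_fold])
```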
Since the tasks for Tox21, SIDER, and MUV are all classification tasks for binary assay data, with positive and negative results (from a clinical trial, for example), the performance values are reported with the AUROC metric, as mentioned. With AUROC, a value of 0.5 indicates no predictive power, while a value of 1.0 indicates that every outcome in the held-out data set is ranked correctly, with every positive scored above every negative [9]. A value less than 0.5 can be interpreted as 1.0 minus the metric value. This operation corresponds to inverting the model, so that True becomes False, and vice versa. In this way, a metric value between 0.5 and 1.0 can always be realized. Each data set performance is reported with a standard deviation that reflects the dispersion of metric values across classification tasks and across the k cross-validation folds.
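The inversion rule can be checked numerically: negating a classifier's scores reverses its ranking, so for scores without ties the inverted model's AUROC is exactly one minus the original. A minimal NumPy sketch, using made-up toy labels and scores rather than data from this study; the `auroc` helper is our own rank-based implementation.

```python
import numpy as np


def auroc(y_true, scores):
    """Probability that a random positive is scored above a random negative."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    diff = scores[y_true][:, None] - scores[~y_true][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


y = [1, 1, 0, 0, 1, 0]                 # toy labels
s = [0.2, 0.3, 0.9, 0.7, 0.4, 0.1]     # toy scores from a "backwards" model

a = auroc(y, s)                               # 1/3: worse than chance
a_inverted = auroc(y, -np.asarray(s))         # 2/3: True and False swapped
print(a, a_inverted, max(a, a_inverted))
```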
Our computed values match well with those published by Altae-Tran, et al. [1], and essentially confirm the performance metric values in their ACS Central Science publication. The first and second rows of Table 3 show classification performance for RF and GC, respectively, as computed by Altae-Tran, et al.; these single-task RF and GC results are presented as a baseline for comparison with the one-shot methods.
† The use of k-fold cross-validation improves the estimate of the model's performance as it would be if trained on all of the data, rather than on a training subset with a portion reserved for testing. Since we cannot directly measure the performance of a model trained on the full data set (no testing data would remain), k-fold cross-validation provides a best guess of a performance we cannot observe until final deployment, when a deployed network would be trained on all of the data.
Model | Tox21 | SIDER | MUV |
---|---|---|---|
Random Forests‡,⁑ | 0.539 ± 0.049 | 0.557 ± 0.059 | 0.751 ± 0.062Ω |
Graphical Convolution‡,⁑ | 0.625 ± 0.036 | 0.482 ± 0.038 | 0.583 ± 0.061 |
Siamese Networks | 0.783 ± 0.009 | 0.660 ± 0.088 | 0.500 ± 0.043 |
AttLSTM | 0.759 ± 0.007 | 0.607 ± 0.080 | 0.500 ± 0.058 |
IterRefLSTM | 0.807 ± 0.003Ω | 0.751 ± 0.002Ω | 0.533 ± 0.051 |
Table 3. AUROC performance metric values for each one-shot method, plus the random forests (RF) and graphical convolution (GC) methods. Metric values were measured on the Tox21, SIDER, and MUV test data sets using a trained modelΦ. The reported variability arises from the random choice of support set when the trained model is evaluated on a test task. First, a support setΨ, S, of 20 data points is chosen from the data points for a test task. The metric is then evaluated over the remaining points in the test task data set. This process is repeated 20 times for every test task in the data set, and the mean and standard deviation of all AUROC values generated in this way are computed.
Finally, for each data set (Tox21, SIDER, and MUV), the reported performance result is the median performance value across all test tasks for that data set. Using the median implicitly assumes that performance on any individual task matters less than typical performance, and that tasks tend to do well or poorly together, without too much variance across tasks. However, a median can mask outliers, where performance on one task might be very bad. If there are principled reasons to discount outlying tasks, then the median across task performances is an effective way of limiting their influence.
‡ The performance measures for RF and GC were computed with one-fold validation (i.e., no cross-validation), because the RF and GC scripts available with our version of DeepChem (July 2017) are written to perform only one-fold validation with these models.
⁑ The variances of the k-fold cross-validation performance estimates were determined by pooling all performance values and then finding the median variance of the entire pool. More sophisticated techniques exist for estimating the variance from a cross-validated set, and the reader is invited to investigate other methods (see Nadeau, et al. [13]).
Ω The performance of IterRefLSTM on the Tox21 data set is the only performance which rates as good. IterRefLSTM performance on the SIDER data set rates as fair, as does RF on MUV.
Φ Because network training, which uses the computationally expensive backpropagation algorithm, is much slower than network inference (predicting outcomes), each training step uses only a batch, B, of data points rather than all of the non-support training data. For a given episode of training, a support set, S, of 20 data points, along with a batch of queries, B, of 128 data points, is selected for each training task in each of the training sets. A total of 2000 × n_train training episodes is performed, with one step of minimization by the Adam optimizer per episode [11], where n_train is the number of training tasks. After all training episodes have been completed, an intermediate information structure for the attLSTM and IterRefLSTM models, the embedding vector set described earlier, is produced. The same support set size, S, is also used during model testing on the held-out test tasks.
Ψ Every support set, S, whether selected during training or testing, was chosen so that it contained 10 positive and 10 negative samples for the task in question. In the full study in [1], however, variations in the number of positive and negative samples in the support set, S, are explored. The investigators found that sampling more data points in S, rather than increasing the number of backpropagation iterations, resulted in better model performance.
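The episode structure described in footnotes Φ and Ψ can be sketched as follows. It is a NumPy illustration under stated assumptions, not DeepChem's episode generator: labels are a 0/1 array for a single task, `score_fn` is a hypothetical stand-in for a trained one-shot model that scores query points given a support set, and all helper names are invented for illustration. The evaluation helper mirrors the 20-trial support resampling behind the means and standard deviations in Table 3, for one task.

```python
import numpy as np


def auroc(y_true, scores):
    """AUROC: probability that a random positive outscores a random negative."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    diff = scores[y_true][:, None] - scores[~y_true][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def sample_support(y, n_pos=10, n_neg=10, rng=None):
    """Indices of a balanced support set S (10 positives + 10 negatives)."""
    rng = np.random.default_rng() if rng is None else rng
    pos = rng.choice(np.flatnonzero(y == 1), n_pos, replace=False)
    neg = rng.choice(np.flatnonzero(y == 0), n_neg, replace=False)
    return np.concatenate([pos, neg])


def training_episode(y, batch_size=128, rng=None):
    """One training episode for a task: support S plus a query batch B drawn
    from the task's remaining points. A real training loop would take one
    Adam step on the loss over B per episode."""
    rng = np.random.default_rng() if rng is None else rng
    support = sample_support(y, rng=rng)
    rest = np.setdiff1d(np.arange(len(y)), support)
    batch = rng.choice(rest, min(batch_size, len(rest)), replace=False)
    return support, batch


def evaluate_task(y, score_fn, n_trials=20, rng=None):
    """Resample S n_trials times; score every remaining point of the task with
    score_fn(support_idx, query_idx) and return the mean and std of AUROC."""
    rng = np.random.default_rng() if rng is None else rng
    results = []
    for _ in range(n_trials):
        support = sample_support(y, rng=rng)
        query = np.setdiff1d(np.arange(len(y)), support)
        results.append(auroc(y[query], score_fn(support, query)))
    return float(np.mean(results)), float(np.std(results))


labels = np.array([1] * 30 + [0] * 170)   # toy task: 30 positives, 170 negatives
rng = np.random.default_rng(0)
S, B = training_episode(labels, rng=rng)
print(len(S), int(labels[S].sum()), len(B))  # 20 10 128
```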
It should be noted that, for a support set of 10 positive and 10 negative assay results, our computed results for the Siamese method on MUV do not show any predictive performance. The results published by Altae-Tran, et al., however, indicate marginal, though poor, predictive power, with an AUROC metric value of 0.601 ± 0.041.
We computed the metric several times on both a Tesla P100 16GB GPU and a Tesla M40 GPU, and found that, with this support set composition, the Siamese model has no predictive power, with a metric value of 0.500 ± 0.043 (see Table 3). Our other computed results, for the AttLSTM and IterRefLSTM, concur with published results, which show that neither of these one-shot learning methods has predictive power on MUV data with a support set containing 10 positive and 10 negative assay results.
The Iterative Refinement LSTM shows a narrower dispersion of scores than the other one-shot learning models on Tox21 and SIDER. This result agrees with the published standard deviation values for the IterRefLSTM in Altae-Tran, et al. [1].
Speedup factors are determined by comparing runtimes on the NVIDIA Tesla P100 GPU to runtimes on the Tesla M40 GPU; the runtimes and speedup factors are presented in Tables 4 and 5. Speedup factors are found to be less pronounced for one-shot methods, and an explanation of the speedup results is presented. The approach for training and testing one-shot methods is also described, as it involves some extra considerations which do not apply to graphical convolution.
Model | Tox21 (P100) | SIDER (P100) | MUV (P100) | Tox21 (M40) | SIDER (M40) | MUV (M40) |
---|---|---|---|---|---|---|
Random Forests† | | | | | | |
Graphical Convolution | 38 | 79 | 64 | 41 | 100 | 720 |
Siamese | 857 | 2,180 | 1,464 | 956 | 2,407 | 1,617 |
AttLSTM | 933 | 2,405 | 1,591 | 1,041 | 2,581 | 1,725 |
IterRefLSTM | 1,006 | 2,511 | 1,680 | 1,101 | 2,721 | 1,834 |
Table 4. Runtimes for each model on the NVIDIA Tesla P100 16GB PCIe GPU and the Tesla M40 GPU. All runtimes are in seconds.
† RF models run entirely on the CPU, so their runtimes reflect CPU time; they are not considered when determining GPU speedup factors.
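The speedup factors can be recomputed directly from the runtimes in Table 4. In this short Python sketch, only the runtime values come from Table 4; the dictionary layout and the `speedup` function are our own. The speedup is taken as the M40 runtime divided by the P100 runtime, so values above 1 mean the P100 is faster; this is the ratio that matches the GC factor of 11.25 on MUV (720 / 64).

```python
# Runtimes in seconds, copied from Table 4, as (P100, M40) per model and data set.
RUNTIMES = {
    "Graphical Convolution": {"Tox21": (38, 41), "SIDER": (79, 100), "MUV": (64, 720)},
    "Siamese": {"Tox21": (857, 956), "SIDER": (2180, 2407), "MUV": (1464, 1617)},
    "AttLSTM": {"Tox21": (933, 1041), "SIDER": (2405, 2581), "MUV": (1591, 1725)},
    "IterRefLSTM": {"Tox21": (1006, 1101), "SIDER": (2511, 2721), "MUV": (1680, 1834)},
}


def speedup(p100_seconds, m40_seconds):
    """How many times faster the P100 run is than the M40 run."""
    return m40_seconds / p100_seconds


for model, by_dataset in RUNTIMES.items():
    factors = {name: round(speedup(*pair), 3) for name, pair in by_dataset.items()}
    print(model, factors)
```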
A quick inspection of the results in Table 3 shows that the one-shot methods perform better on the Tox21 and SIDER data sets, but not on the MUV data. A reason for the poor performance of one-shot methods on MUV data is proposed below.
Limitations of One-Shot Networks
One-shot networks extract more information from the prior (support) data than RF or GC, but with a limitation: one-shot methods are only successful when data in the held-out testing set is sufficiently similar to data seen during training. Networks trained using one-shot methods do not perform well when trying to classify data that is too dissimilar from the sample data used for training. In the context of drug discovery, this problem is encountered when applying one-shot learning to the Maximum Unbiased Validation (MUV) data set, for example [10]. Benchmark results show that all three one-shot learning methods explored here do little better than pure chance when making classification predictions with MUV data (see Table 3).
The MUV data set contains around 93,000 compounds, and represents a more diverse collection of molecular scaffolds than Tox21 and SIDER. One-shot methods probably do not perform as well on this data set because there is less structural similarity between the elements of MUV than between those of Tox21 and SIDER. One-shot networks require some amount of structural similarity within the data set in order to extrapolate from limited data and correctly classify new, but similar, compounds.
A metric of self-similarity within a data set could be computed as a size-independent (intensive) measure, in which every element is compared to every other element using some similarity (attention) measure, such as cosine similarity. The similarity values can be summed over all unique comparisons and then normalized by dividing by the number of unique comparisons, N(N − 1)/2, between the N elements of the set.
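A minimal NumPy sketch of such a self-similarity score. It assumes each compound has already been encoded as a fixed-length, nonzero vector (for example a binary fingerprint), uses cosine similarity as the pairwise measure, and the function name is hypothetical; this score was not computed for the data sets in this study.

```python
import numpy as np


def mean_pairwise_cosine(vectors):
    """Hypothetical self-similarity score for a data set.

    vectors: (N, d) array, one nonzero vector per compound (e.g. a fingerprint).
    Sums cosine similarity over all N(N - 1)/2 unique pairs and divides by the
    number of pairs, so the result does not grow with N.
    """
    X = np.asarray(vectors, dtype=float)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-length rows
    sims = X @ X.T                                    # (N, N) cosine similarities
    i, j = np.triu_indices(len(X), k=1)               # unique pairs with i < j
    return float(sims[i, j].sum() / len(i))           # len(i) == N(N - 1)/2


fps = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]])  # toy binary fingerprints
print(round(mean_pairwise_cosine(fps), 3))  # 0.408
```

For a data set the size of MUV, the full N × N similarity matrix would be too large to hold in memory, so the pairs would have to be processed in blocks or sampled.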
Model | Tox21 | SIDER | MUV |
---|---|---|---|
GC | 1.079 | 1.266 | 11.25α |
Siamese | 1.116λ | 1.104 | 1.105 |
AttLSTM | 1.116λ | 1.073 | 1.084 |
IterRefLSTM | 1.094λ | 1.084 | 1.092 |
Table 5. Speedup factors, comparing the Tesla P100 16GB GPU to the Tesla M40. All speedups are Tesla M40 runtimes divided by Tesla P100 runtimes.
α The greatest speedup is observed with GC on the MUV data set (Table 5).
The speedup also drops most precipitously in the transition from GC to the one-shot models. Table 4 indicates that the GC model runs faster than the one-shot methods across all data sets. This is not surprising, because the purely graphical model is more amenable to GPU acceleration. However, it is crucial to note that GC models perform worse than one-shot models on Tox21 and SIDER, but not on MUV. On MUV, GC has nearly no predictive ability (Table 3), while the one-shot models have essentially no predictive ability at all.
λ The one-shot networks, while providing a substantial improvement in predictive performance, do not show a significant speedup, as the values in the Siamese, AttLSTM, and IterRefLSTM rows show. The nearly absent speedup could arise from frequent data transfers between GPU and system memory. Note, however, that there is a slight but consistent improvement in speedup for the one-shot networks on the Tox21 set, which may therefore require fewer transfers to system memory. The generally flat speedup for one-shot methods may also be related to the LSTM elements.
Generally, deep convolutional network models such as GC, and models which benefit from a large data set containing structurally diverse groups, such as RF and GC, perform better on the MUV data. RF, for example, shows the best performance on MUV, even if only fair. Deep networks have demonstrated that, provided enough layers, they have the information-holding capacity required to learn and retain representations for the MUV data. This information-holding capacity is what enables them to classify between the large number of structurally diverse classes in MUV. It may be that the hyperparameters of the graphical convolutional network were not set to values that would allow the GC model to reach a poor to fair level of performance on MUV. In their paper, Altae-Tran, et al. stated that the hyperparameters for the convolutional networks were not optimized, and that there may be an opportunity to improve performance there [1].
Remarks on Neural Network Information Structure, and How One-Shot Networks are Different
All neural networks require training data in order to develop structure under training pressure. In image classification networks, feature complexity becomes stratified through the network's layers under training pressure. The lowest layers emerge as edge detectors, with successive layers building upon features from previous layers. The second layer, for example, can build corner detectors or curved-edge detectors by detecting combinations of simpler edges. Through a buildup of feature complexity, higher layers eventually emerge which can detect complex, high-level features such as faces. The combinatorial size of the detectable feature space grows with the number of viable filters (kernels) connecting each layer to the preceding layer. In hierarchical Natural Language Processing (NLP) networks, layer complexity can similarly progress from sentence features to paragraphs, then chapters, and finally whole-book vector representations, which consist of succinct thematic summaries of written works.
To reiterate, all networks require information structure, acquired under training pressure, to develop some inner representation, or “belief”, about data. Deep networks allow more diverse structures to be learned than one-shot networks, which are limited in their ability to learn diverse representations. The structural features of one-shot networks are designed to improve the extraction of information from support data, in order to learn a representation which can be used to extrapolate from a smaller group of similar classes. One-shot methods do not perform as well as RF on MUV because they are not designed to produce a useful network from a data set whose elements have molecular scaffolds as dissimilar as those in MUV.
Transfer Learning with One-Shot Learning Network Architecture
A network trained on the Tox21 data set was evaluated on the SIDER data set. The performance metric values shown in Table 6 indicate that the network trained on Tox21 has nearly no predictive power on the SIDER data. This indicates that the performance does not generalize well to new chemical scaffolds, which supports the explanation for why one-shot methods do poorly at predicting the results for the MUV data set.
Transfer | Siamese | AttLSTM | IterRefLSTM |
---|---|---|---|
To SIDER from Tox21 | 0.505 | 0.502 | 0.504 |
Table 6. Transfer learning to SIDER from Tox21. These results agree with the performance metric values reported for transfer learning in [1], and support the conclusion that transfer learning between data sets yields no predictive capability unless the data sets are sufficiently similar.
Conclusion
For binary classification tasks associated with small population data sources, one-shot learning methods may provide significantly better results than the baseline performances of graphical convolution and random forests. The results show that the performance of one-shot learning methods may depend on the diversity of molecular scaffolds in a data set. With MUV, for example, one-shot methods did not extrapolate well to unseen molecular scaffolds. The failure of transfer learning from the Tox21 network to correctly predict SIDER assay outcomes also indicates that training on one data set may not be easily generalized with one-shot networks.
The Iterative Refinement LSTM method developed in [1] demonstrates that LSTMs can generalize to similar experimental assays which are not identical to assays in the data set, but which have some common relation.
References
1. Altae-Tran, Han, Bharath Ramsundar, Aneesh S. Pappu, and Vijay Pande. “Low Data Drug Discovery with One-Shot Learning.” ACS Central Science 3.4 (2017): 283-293.
2. Wu, Zhenqin, et al. “MoleculeNet: A Benchmark for Molecular Machine Learning.” arXiv preprint arXiv:1703.00564 (2017).
3. Hariharan, Bharath, and Ross Girshick. “Low-shot visual object recognition.” arXiv preprint arXiv:1606.02819 (2016).
4. Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. “Siamese neural networks for one-shot image recognition.” ICML Deep Learning Workshop. Vol. 2. 2015.
5. Vinyals, Oriol, et al. “Matching networks for one shot learning.” Advances in Neural Information Processing Systems. 2016.
6. Wang, Peilu, et al. “A unified tagging solution: Bidirectional LSTM recurrent neural network with word embedding.” arXiv preprint arXiv:1511.00215 (2015).
7. Duvenaud, David K., et al. “Convolutional networks on graphs for learning molecular fingerprints.” Advances in Neural Information Processing Systems. 2015.
8. Lusci, Alessandro, Gianluca Pollastri, and Pierre Baldi. “Deep architectures and deep learning in chemoinformatics: the prediction of aqueous solubility for drug-like molecules.” Journal of Chemical Information and Modeling 53.7 (2013): 1563-1575.
9. Kennis Research. “Receiver Operating Curves Applet.” 2016.
10. “Maximum Unbiased Validation (MUV) Chemical Data Set.”
11. Kingma, Diederik P., and Jimmy Ba. “Adam: A Method for Stochastic Optimization.” arXiv preprint arXiv:1412.6980 (arxiv.org/pdf/1412.6980v8.pdf).
12. University of Nebraska Medical Center. Online information resource on AUROC.
13. Nadeau, Claude, and Yoshua Bengio. “Inference for the Generalization Error.” Machine Learning 52.3 (2003): 239-281.