Skip to content

Commit 09ffe18

Browse files
committed
add papers implemented
1 parent 186aef2 commit 09ffe18

File tree

14 files changed

+77
-23
lines changed

14 files changed

+77
-23
lines changed

Diff for: doc/source/api/layers.rst

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ MutualInformation
2121
.. autoclass:: MutualInformation
2222
:members:
2323

24+
SinusoidalPositionEmbedding
25+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
26+
.. autoclass:: SinusoidalPositionEmbedding
27+
:members:
28+
2429
PairNorm
2530
^^^^^^^^
2631
.. autoclass:: PairNorm

Diff for: doc/source/bibliography.rst

+7
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@
5757
.. Retrosynthesis
5858
.. _G2Gs: https://arxiv.org/pdf/2003.12725.pdf
5959

60+
.. Protein Representation Learning
61+
.. _TAPE: https://proceedings.neurips.cc/paper/2019/file/37f65c068b7723cd7809ee2d31d7861c-Paper.pdf
62+
.. _ProteinCNN: https://arxiv.org/pdf/2011.03443.pdf
63+
.. _ESM: https://www.biorxiv.org/content/10.1101/622803v1.full.pdf
64+
.. _GearNet: https://arxiv.org/pdf/2203.06125.pdf
65+
66+
.. Knowledge Graph Reasoning
6067
.. _TransE: http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf
6168
.. _DistMult: https://arxiv.org/pdf/1412.6575.pdf
6269
.. _ComplEx: http://proceedings.mlr.press/v48/trouillon16.pdf

Diff for: doc/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
#
8686
html_theme = "furo"
8787

88-
html_logo = "../../asset/logo.svg"
88+
html_logo = "../../asset/torchdrug_logo_full.svg"
8989

9090
# Theme options are theme-specific and customize the look and feel of a theme
9191
# further. For a list of options available for each theme, see the

Diff for: doc/source/paper.rst

+35
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,41 @@ Retrosynthesis
147147
:class:`SynthonCompletion <torchdrug.tasks.SynthonCompletion>`,
148148
:class:`Retrosynthesis <torchdrug.tasks.Retrosynthesis>`
149149

150+
Protein Representation Learning
151+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
152+
153+
1. `Evaluating Protein Transfer Learning with TAPE <TAPE_>`_
154+
155+
Roshan Rao, Nicholas Bhattacharya, Neil Thomas, Yan Duan, Xi Chen, John Canny, Pieter Abbeel, Yun S Song. NeurIPS 2019.
156+
157+
:class:`SinusoidalPositionEmbedding <torchdrug.layers.SinusoidalPositionEmbedding>`
158+
:class:`SelfAttentionBlock <torchdrug.layers.SelfAttentionBlock>`
159+
:class:`ProteinResNetBlock <torchdrug.layers.ProteinResNetBlock>`
160+
:class:`ProteinBERTBlock <torchdrug.layers.ProteinBERTBlock>`
161+
:class:`ProteinResNet <torchdrug.models.ProteinResNet>`
162+
:class:`ProteinLSTM <torchdrug.models.ProteinLSTM>`
163+
:class:`ProteinBERT <torchdrug.models.ProteinBERT>`
164+
165+
2. `Is Transfer Learning Necessary for Protein Landscape Prediction? <ProteinCNN_>`_
166+
167+
Amir Shanehsazzadeh, David Belanger, David Dohan. arXiv 2020.
168+
169+
:class:`ProteinCNN <torchdrug.models.ProteinCNN>`
170+
171+
3. `Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences <ESM_>`_
172+
173+
Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, Rob Fergus. PNAS 2021.
174+
175+
:class:`EvolutionaryScaleModeling <torchdrug.models.EvolutionaryScaleModeling>`
176+
177+
4. `Protein Representation Learning by Geometric Structure Pretraining <GearNet_>`_
178+
179+
Zuobai Zhang, Minghao Xu, Arian Jamasb, Vijil Chenthamarakshan, Aurélie Lozano, Payel Das, Jian Tang. arXiv 2022.
180+
181+
:class:`GeometricRelationalGraphConv <torchdrug.layers.GeometricRelationalGraphConv>`
182+
:class:`GeometryAwareRelationalGraphNeuralNetwork <torchdrug.models.GeometryAwareRelationalGraphNeuralNetwork>`
183+
:mod:`torchdrug.layers.geometry`
184+
150185
Knowledge Graph Reasoning
151186
^^^^^^^^^^^^^^^^^^^^^^^^^
152187

Diff for: torchdrug/data/dataset.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -715,13 +715,12 @@ def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, n
715715
self.num_samples = num_samples
716716

717717
@utils.copy_args(data.Protein.from_molecule)
718-
def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbose=0, **kwargs):
718+
def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
719719
"""
720720
Load the dataset from pdb files.
721721
722722
Parameters:
723723
pdb_files (list of str): pdb file names
724-
sanitize (bool, optional): whether to sanitize the molecule
725724
transform (Callable, optional): protein sequence transformation function
726725
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
727726
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
@@ -744,7 +743,7 @@ def load_pdbs(self, pdb_files, sanitize=True, transform=None, lazy=False, verbos
744743
pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
745744
for i, pdb_file in enumerate(pdb_files):
746745
if not lazy or i == 0:
747-
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
746+
mol = Chem.MolFromPDBFile(pdb_file)
748747
if not mol:
749748
logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
750749
continue

Diff for: torchdrug/data/protein.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def from_sequence(cls, sequence, atom_feature="default", bond_feature="default",
305305
@classmethod
306306
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
307307
def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", residue_feature="default",
308-
mol_feature=None, kekulize=False, sanitize=False):
308+
mol_feature=None, kekulize=False):
309309
"""
310310
Create a protein from a PDB file.
311311
@@ -319,11 +319,10 @@ def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", resi
319319
Note this only affects the relation in ``edge_list``.
320320
For ``bond_type``, aromatic bonds are always stored explicitly.
321321
By default, aromatic bonds are stored.
322-
sanitize (bool, optional): whether to sanitize the molecule
323322
"""
324323
if not os.path.exists(pdb_file):
325324
raise FileNotFoundError("No such file `%s`" % pdb_file)
326-
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
325+
mol = Chem.MolFromPDBFile(pdb_file)
327326
if mol is None:
328327
raise ValueError("RDKit cannot read PDB file `%s`" % pdb_file)
329328
return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
@@ -1052,7 +1051,7 @@ def from_sequence(cls, sequences, atom_feature="default", bond_feature="default"
10521051
@classmethod
10531052
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
10541053
def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
1055-
mol_feature=None, kekulize=False, sanitize=False):
1054+
mol_feature=None, kekulize=False):
10561055
"""
10571056
Create a protein from a list of PDB files.
10581057
@@ -1066,11 +1065,10 @@ def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", res
10661065
Note this only affects the relation in ``edge_list``.
10671066
For ``bond_type``, aromatic bonds are always stored explicitly.
10681067
By default, aromatic bonds are stored.
1069-
sanitize (bool, optional): whether to sanitize the molecule
10701068
"""
10711069
mols = []
10721070
for pdb_file in pdb_files:
1073-
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize)
1071+
mol = Chem.MolFromPDBFile(pdb_file)
10741072
mols.append(mol)
10751073

10761074
return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)

Diff for: torchdrug/layers/common.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def forward(self, input):
7474
class GaussianSmearing(nn.Module):
7575
r"""
7676
Gaussian smearing from
77-
`SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.
77+
`SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.``
7878
7979
There are two modes for Gaussian smearing.
8080
@@ -167,7 +167,7 @@ def forward(self, graph, input):
167167
class InstanceNorm(nn.modules.instancenorm._InstanceNorm):
168168
"""
169169
Instance normalization for graphs. This layer follows the definition in
170-
`GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training`.
170+
`GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training`_.
171171
172172
.. _GraphNorm\: A Principled Approach to Accelerating Graph Neural Network Training:
173173
https://arxiv.org/pdf/2009.03294.pdf
@@ -325,13 +325,23 @@ def forward(self, *args, **kwargs):
325325

326326

327327
class SinusoidalPositionEmbedding(nn.Module):
328+
"""
329+
Positional embedding based on sine and cosine functions, proposed in `Attention Is All You Need`_.
330+
331+
.. _Attention Is All You Need:
332+
https://arxiv.org/pdf/1706.03762.pdf
333+
334+
Parameters:
335+
output_dim (int): output dimension
336+
"""
328337

329338
def __init__(self, output_dim):
330339
super(SinusoidalPositionEmbedding, self).__init__()
331340
inverse_frequency = 1 / (10000 ** (torch.arange(0.0, output_dim, 2.0) / output_dim))
332341
self.register_buffer("inverse_frequency", inverse_frequency)
333342

334343
def forward(self, input):
344+
""""""
335345
# input: [B, L, ...]
336346
positions = torch.arange(input.shape[1] - 1, -1, -1.0, dtype=input.dtype, device=input.device)
337347
sinusoidal_input = torch.outer(positions, self.inverse_frequency)

Diff for: torchdrug/layers/geometry/graph.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
4646
def edge_residue_type(self, graph, edge_list):
4747
node_in, node_out, _ = edge_list.t()
4848
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
49-
in_residue_type = graph.edge_residue_type[residue_in]
50-
out_residue_type = graph.edge_residue_type[residue_out]
49+
in_residue_type = graph.residue_type[residue_in]
50+
out_residue_type = graph.residue_type[residue_out]
5151

5252
return torch.cat([
5353
functional.one_hot(in_residue_type, len(data.Protein.residue2id)),
@@ -57,8 +57,8 @@ def edge_residue_type(self, graph, edge_list):
5757
def edge_gearnet(self, graph, edge_list, num_relation):
5858
node_in, node_out, r = edge_list.t()
5959
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
60-
in_residue_type = graph.edge_residue_type[residue_in]
61-
out_residue_type = graph.edge_residue_type[residue_out]
60+
in_residue_type = graph.residue_type[residue_in]
61+
out_residue_type = graph.residue_type[residue_out]
6262
sequential_dist = torch.abs(residue_in - residue_out)
6363
spatial_dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1)
6464

Diff for: torchdrug/models/bert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
6767
dict with ``residue_feature`` and ``graph_feature`` fields:
6868
residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
6969
"""
70-
input = graph.edge_residue_type
70+
input = graph.residue_type
7171
size_ext = graph.num_residues
7272
# Prepend BOS
7373
bos = torch.ones(graph.batch_size, dtype=torch.long, device=self.device) * self.num_residue_type

Diff for: torchdrug/models/esm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
100100
dict with ``residue_feature`` and ``graph_feature`` fields:
101101
residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
102102
"""
103-
input = graph.edge_residue_type
103+
input = graph.residue_type
104104
input = self.mapping[input]
105105
size = graph.num_residues
106106
if (size > self.max_input_length).any():

Diff for: torchdrug/models/lstm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
from torch.nn import functional as F
66

7-
from torchdrug import core, layers
7+
from torchdrug import core
88
from torchdrug.layers import functional
99
from torchdrug.core import Registry as R
1010

Diff for: torchdrug/models/physicochemical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
8888
Returns:
8989
dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`
9090
"""
91-
input = graph.edge_residue_type
91+
input = graph.residue_type
9292

9393
x = self.property[input] # num_residue * 8
9494
x_mean = scatter_mean(x, graph.residue2graph, dim=0, dim_size=graph.batch_size) # batch_size * 8

Diff for: torchdrug/models/statistic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
5858
Returns:
5959
dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`
6060
"""
61-
input = graph.edge_residue_type
61+
input = graph.residue_type
6262

6363
index = input[:-1] * self.num_residue_type + input[1:]
6464
index = graph.residue2graph[:-1] * self.input_dim + index

Diff for: torchdrug/tasks/pretrain.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,10 @@ def predict_and_target(self, batch, all_loss=None, metric=None):
133133
input = graph.node_feature.float()
134134
input[node_index] = 0
135135
else:
136-
target = graph.edge_residue_type[node_index]
136+
target = graph.residue_type[node_index]
137137
with graph.residue():
138138
graph.residue_feature[node_index] = 0
139-
graph.edge_residue_type[node_index] = 0
139+
graph.residue_type[node_index] = 0
140140
# Generate masked edge features. Any better implementation?
141141
if self.graph_construction_model:
142142
graph = self.graph_construction_model.apply_edge_layer(graph)

0 commit comments

Comments
 (0)