
Commit 4dc719f

optimize code for bwd performance and refine code. (#145)
* optimize code for bwd performance and refine code.
* refine README to add test SW version
1 parent 95150c3 commit 4dc719f

File tree: 6 files changed, +49 -9 lines changed

README.md (+2 -1)

@@ -210,11 +210,12 @@ python inference.py target.fasta data/pdb_mmcif/mmcif_files/ \
 
 ### Inference or Training on Intel Habana
 
-To run AlphaFold inference or training on Intel Habana, you can follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments.
+To run AlphaFold inference or training on Intel Habana, you can follow the instructions in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/) to set up your environment on Amazon EC2 DL1 instances or on-premise environments. Please use SynapseAI R1.7.1 for testing, as this is the version that has been verified internally.
 
 Once you have prepared your dataset and installed fastfold, you can use the following scripts:
 
 ```shell
+cd fastfold/habana/fastnn/custom_op/; python setup.py build; cd -   # for Gaudi; on Gaudi2 use setup2.py instead
 bash habana/inference.sh
 bash habana/train.sh
 ```
fastfold/habana/distributed/__init__.py (+2 -2)

@@ -1,8 +1,8 @@
 from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
                    gather, reduce, row_to_col, scatter)
-from .core import init_dist
+from .core import init_dist, get_data_parallel_world_size
 
 __all__ = [
-    'init_dist', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
+    'init_dist', 'get_data_parallel_world_size', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
     'col_to_row', 'row_to_col', 'All_to_All'
 ]
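The body of `get_data_parallel_world_size` lives in `.core` and is not part of this diff. As a rough, hedged sketch of what such a helper typically does, assuming it falls back to the default torch.distributed process group (the real FastFold helper may consult a dedicated data-parallel group instead):

```python
# Hypothetical sketch only; the actual implementation is in .core and not shown
# in this commit.
import torch.distributed as dist


def get_data_parallel_world_size() -> int:
    """Return the number of data-parallel workers, or 1 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
```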

fastfold/utils/rigid_utils.py (+28 -1)

@@ -18,7 +18,7 @@
 
 import numpy as np
 import torch
-
+import fastfold.habana as habana
 
 def rot_matmul(
     a: torch.Tensor,
@@ -34,6 +34,19 @@ def rot_matmul(
     Returns:
         The product ab
     """
+    if habana.is_habana():
+        if len(a.shape) == 4 and a.shape[1] == 1:
+            aa = a.permute(0, 1, 3, 2)
+            bb = b.permute(0, 1, 3, 2)
+            cc = bb @ aa
+            cc = cc.permute(0, 1, 3, 2)
+            return cc
+        elif len(a.shape) == 4 and a.shape[1] != 1:
+            pass
+        else:
+            cc = a @ b
+            return cc
+
     row_1 = torch.stack(
         [
             a[..., 0, 0] * b[..., 0, 0]
@@ -94,6 +107,20 @@ def rot_vec_mul(
     Returns:
         [*, 3] rotated coordinates
     """
+    if habana.is_habana():
+        cont = True
+        if len(t.shape) == 4 and t.shape[1] == 1:
+            cont = False
+        elif len(t.shape) == 3 and t.shape[0] != r.shape[0] and t.shape[0] == 1:
+            cont = False
+
+        if cont:
+            tt = t.unsqueeze(-2)
+            rr = r.transpose(-2, -1)
+            cc = tt @ rr
+            cc = cc.squeeze(-2)
+            return cc
+
     x = t[..., 0]
     y = t[..., 1]
     z = t[..., 2]
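Both Habana fast paths replace the element-wise formulation below them with a single batched matmul expressed through permutes and transposes. A standalone sanity check (not part of the commit) that the rearranged forms equal the plain products `a @ b` and `r @ t`:

```python
# Standalone check of the algebraic identities used on the Habana path.
import torch

# rot_matmul fast path: (b^T @ a^T)^T == a @ b, with permute(0, 1, 3, 2) acting
# as the per-matrix transpose for [*, 1, 3, 3] inputs.
a = torch.randn(2, 1, 3, 3)
b = torch.randn(2, 1, 3, 3)
cc = (b.permute(0, 1, 3, 2) @ a.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
print(torch.allclose(cc, a @ b, atol=1e-6))  # True

# rot_vec_mul fast path: treating t as a row vector, (t @ r^T) == r @ t.
r = torch.randn(5, 3, 3)
t = torch.randn(5, 3)
vv = (t.unsqueeze(-2) @ r.transpose(-2, -1)).squeeze(-2)
print(torch.allclose(vv, (r @ t.unsqueeze(-1)).squeeze(-1), atol=1e-6))  # True
```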

habana/inference.sh (+3 -0)

@@ -1,3 +1,6 @@
+export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
+export PYTHONPATH=./:$PYTHONPATH
+
 # add '--gpus [N]' to use N gpus for inference
 # add '--enable_workflow' to use parallel workflow for data processing
 # add '--use_precomputed_alignments [path_to_alignments]' to use precomputed msa

habana/train.py (+10 -4)

@@ -10,7 +10,7 @@
 import fastfold.habana as habana
 from fastfold.config import model_config
 from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
-from fastfold.habana.distributed import init_dist
+from fastfold.habana.distributed import init_dist, get_data_parallel_world_size
 from fastfold.habana.inject_habana import inject_habana
 from fastfold.model.hub import AlphaFold, AlphaFoldLoss, AlphaFoldLRScheduler
 from fastfold.utils.tensor_utils import tensor_tree_map
@@ -156,7 +156,8 @@ def main():
     model = inject_habana(model)
 
     model = model.to(device="hpu")
-    model = DDP(model)
+    if get_data_parallel_world_size() > 1:
+        model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)
 
     train_dataset, test_dataset = SetupTrainDataset(
         config=config.data,
@@ -201,27 +202,32 @@ def main():
                     isVerbose=args.hmp_verbose)
         print("========= HMP ENABLED!!")
 
+    idx = 0
     for epoch in range(200):
         model.train()
         train_dataloader = tqdm(train_dataloader)
         for batch in train_dataloader:
             perf = hpu_perf("train step")
-            batch = {k: torch.as_tensor(v).to(device="hpu") for k, v in batch.items()}
+            batch = {k: torch.as_tensor(v).to(device="hpu", non_blocking=True) for k, v in batch.items()}
             optimizer.zero_grad()
+            perf.checknow("prepare input and zero grad")
             output = model(batch)
             perf.checknow("forward")
 
             batch = tensor_tree_map(lambda t: t[..., -1], batch)
+            perf.checknow("prepare loss input")
             loss, loss_breakdown = criterion(output, batch, _return_breakdown=True)
             perf.checknow("loss")
 
             loss.backward()
-            train_dataloader.set_postfix(loss=float(loss))
+            if idx % 10 == 0:
+                train_dataloader.set_postfix(loss=float(loss))
             perf.checknow("backward")
 
             with hmp.disable_casts():
                 optimizer.step()
             perf.checknow("optimizer")
+            idx += 1
 
         lr_scheduler.step()
 
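The training-loop changes move host-to-device copies to non_blocking, throttle the tqdm postfix to every 10th step, and wrap the model in DDP only when more than one data-parallel worker is present. A minimal standalone sketch of the conditional DDP pattern; `maybe_wrap_ddp` is an illustrative name, not something from the repo, and it queries torch.distributed directly rather than FastFold's helper:

```python
# Illustrative sketch of the "wrap in DDP only when needed" pattern used above.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def maybe_wrap_ddp(model: torch.nn.Module) -> torch.nn.Module:
    """Skip the DDP wrapper (and its all-reduce machinery) for single-worker runs."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        # gradient_as_bucket_view avoids an extra gradient copy; a larger
        # bucket_cap_mb (in MB) means fewer, larger all-reduce buckets.
        model = DDP(model, gradient_as_bucket_view=True, bucket_cap_mb=400)
    return model
```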
habana/train.sh (+4 -1)

@@ -1,4 +1,7 @@
-DATA_DIR=/mnt/usb/training-demo
+export GC_KERNEL_PATH=./fastfold/habana/fastnn/custom_op/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
+export PYTHONPATH=./:$PYTHONPATH
+
+DATA_DIR=../FastFold-dataset/train
 
 hpus_per_node=1
 