Skip to content

Commit eba4968

Browse files
authored
fix perf benchmark for API change (#177)
1 parent 0568130 commit eba4968

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

benchmark/perf.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ def main():
1212

1313
parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
1414
parser.add_argument("--dap-size", default=1, type=int, help='batch size')
15-
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
1615
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
1716
parser.add_argument('--res-length',
1817
default=256,
@@ -85,7 +84,9 @@ def forward(self, node, pair, node_mask, pair_mask):
8584
if args.openfold:
8685
attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
8786
else:
88-
attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
87+
first_block = idx == 0
88+
last_block = idx == args.layers - 1
89+
attn_layers.append(Evoformer(c_m=args.cm, c_z=args.cz, first_block=first_block, last_block=last_block))
8990
attn_layers[idx].cuda()
9091
attn_layers[idx].to(dtype=precision)
9192

@@ -97,22 +98,23 @@ def forward(self, node, pair, node_mask, pair_mask):
9798
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
9899
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
99100

100-
inputs_node = torch.randn(args.batch_size,
101-
args.msa_length // args.dap_size,
101+
batch_size = 1
102+
inputs_node = torch.randn(batch_size,
103+
args.msa_length,
102104
args.res_length,
103105
args.cm,
104106
dtype=precision,
105107
device=torch.device("cuda")).requires_grad_(True)
106-
inputs_pair = torch.randn(args.batch_size,
107-
args.res_length // args.dap_size,
108+
inputs_pair = torch.randn(batch_size,
109+
args.res_length,
108110
args.res_length,
109111
args.cz,
110112
dtype=precision,
111113
device=torch.device("cuda")).requires_grad_(True)
112-
node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
114+
node_mask = torch.ones((batch_size, args.msa_length, args.res_length),
113115
dtype=precision,
114116
device=torch.device("cuda")).requires_grad_(False)
115-
pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
117+
pair_mask = torch.ones((batch_size, args.res_length, args.res_length),
116118
dtype=precision,
117119
device=torch.device("cuda")).requires_grad_(False)
118120
grads_node = torch.randn_like(inputs_pair)
@@ -129,6 +131,13 @@ def forward(self, node, pair, node_mask, pair_mask):
129131
with_stack=False)
130132
prof.start()
131133

134+
if not args.openfold:
135+
inputs_node = inputs_node.squeeze(0)
136+
inputs_pair = inputs_pair.squeeze(0)
137+
node_mask = node_mask.squeeze(0)
138+
pair_mask = pair_mask.squeeze(0)
139+
grads_node = grads_node.squeeze(0)
140+
132141
for trial in range(0, args.trials + args.warmup_trials):
133142
layer_inputs = inputs_node, inputs_pair
134143
evt_idx = trial - args.warmup_trials
@@ -168,7 +177,7 @@ def forward(self, node, pair, node_mask, pair_mask):
168177
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
169178

170179
print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
171-
args.batch_size, args.msa_length, args.res_length, \
180+
batch_size, args.msa_length, args.res_length, \
172181
args.cm, args.cz, \
173182
elapsed_time_fwd / ( args.trials * args.layers ), \
174183
elapsed_time_bwd / ( args.trials * args.layers )))

0 commit comments

Comments (0)