@@ -12,7 +12,6 @@ def main():
     parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
     parser.add_argument("--dap-size", default=1, type=int, help='DAP size')
-    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
     parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
     parser.add_argument('--res-length',
                         default=256,
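
Note: --batch-size is dropped here because the benchmark below pins the batch dimension to 1. A minimal smoke test of the trimmed parser, using only the flags visible in this hunk:

import argparse

# Rebuild just the flags shown above; '--dap-size' keeps its default of 1.
parser = argparse.ArgumentParser(description='Evoformer Standalone Perf Benchmark')
parser.add_argument("--dap-size", default=1, type=int, help='DAP size')
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of MSA')
parser.add_argument('--res-length', default=256, type=int)

args = parser.parse_args(['--msa-length', '128'])
assert args.msa_length == 128 and args.dap_size == 1  # dashes become underscores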
@@ -85,7 +84,9 @@ def forward(self, node, pair, node_mask, pair_mask):
         if args.openfold:
             attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
         else:
-            attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
+            first_block = idx == 0
+            last_block = idx == args.layers - 1
+            attn_layers.append(Evoformer(c_m=args.cm, c_z=args.cz, first_block=first_block, last_block=last_block))

         attn_layers[idx].cuda()
         attn_layers[idx].to(dtype=precision)
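
The constructor now takes c_m/c_z plus per-layer first_block/last_block flags, so the boundary layers of the stack can be treated specially. A rough sketch of the pattern, with the Evoformer internals assumed rather than taken from this repo:

import torch
import torch.nn as nn

class EvoformerSketch(nn.Module):
    # Hypothetical stand-in: c_m/c_z are the MSA and pair channel widths;
    # first_block/last_block flag the boundary layers of the stack.
    def __init__(self, c_m, c_z, first_block=False, last_block=False):
        super().__init__()
        self.first_block = first_block
        self.last_block = last_block
        self.msa_norm = nn.LayerNorm(c_m)
        self.pair_norm = nn.LayerNorm(c_z)

    def forward(self, m, z):
        # A real block might embed inputs when first_block is set and skip
        # output-irrelevant work when last_block is; here we only normalize.
        return self.msa_norm(m), self.pair_norm(z)

layers = [EvoformerSketch(256, 128, first_block=(i == 0), last_block=(i == 3))
          for i in range(4)]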
@@ -97,22 +98,23 @@ def forward(self, node, pair, node_mask, pair_mask):
         start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
         stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

-    inputs_node = torch.randn(args.batch_size,
-                              args.msa_length // args.dap_size,
+    batch_size = 1
+    inputs_node = torch.randn(batch_size,
+                              args.msa_length,
                               args.res_length,
                               args.cm,
                               dtype=precision,
                               device=torch.device("cuda")).requires_grad_(True)
-    inputs_pair = torch.randn(args.batch_size,
-                              args.res_length // args.dap_size,
+    inputs_pair = torch.randn(batch_size,
+                              args.res_length,
                               args.res_length,
                               args.cz,
                               dtype=precision,
                               device=torch.device("cuda")).requires_grad_(True)
-    node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
+    node_mask = torch.ones((batch_size, args.msa_length, args.res_length),
                            dtype=precision,
                            device=torch.device("cuda")).requires_grad_(False)
-    pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
+    pair_mask = torch.ones((batch_size, args.res_length, args.res_length),
                            dtype=precision,
                            device=torch.device("cuda")).requires_grad_(False)

     grads_node = torch.randn_like(inputs_pair)
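
With batch size hardcoded to 1 and the "// args.dap_size" sharding removed, each rank now allocates the full MSA and pair tensors. A CPU-side shape check, with channel widths chosen for illustration (the real --cm/--cz defaults are not visible in this diff):

import torch

batch_size, msa_length, res_length, cm, cz = 1, 132, 256, 256, 128
inputs_node = torch.randn(batch_size, msa_length, res_length, cm)  # MSA activations
inputs_pair = torch.randn(batch_size, res_length, res_length, cz)  # pair activations
assert inputs_node.shape == (1, 132, 256, 256)
assert inputs_pair.shape == (1, 256, 256, 128)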
@@ -129,6 +131,13 @@ def forward(self, node, pair, node_mask, pair_mask):
                          with_stack=False)
     prof.start()

+    if not args.openfold:
+        inputs_node = inputs_node.squeeze(0)
+        inputs_pair = inputs_pair.squeeze(0)
+        node_mask = node_mask.squeeze(0)
+        pair_mask = pair_mask.squeeze(0)
+        grads_node = grads_node.squeeze(0)
+
     for trial in range(0, args.trials + args.warmup_trials):
         layer_inputs = inputs_node, inputs_pair
         evt_idx = trial - args.warmup_trials
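
The non-OpenFold path now strips the leading batch dimension before the trial loop. squeeze(0) only removes the dimension because batch_size is pinned to 1 above; a small sketch of that behavior:

import torch

x = torch.randn(1, 132, 256, 256)
assert x.squeeze(0).shape == (132, 256, 256)  # size-1 batch dim removed

y = torch.randn(2, 132, 256, 256)
assert y.squeeze(0).shape == (2, 132, 256, 256)  # no-op when the dim is > 1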
@@ -168,7 +177,7 @@ def forward(self, node, pair, node_mask, pair_mask):
         elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

     print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
-        args.batch_size, args.msa_length, args.res_length, \
+        batch_size, args.msa_length, args.res_length, \
         args.cm, args.cz, \
         elapsed_time_fwd / (args.trials * args.layers), \
         elapsed_time_bwd / (args.trials * args.layers)))
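
The reported numbers average CUDA event deltas over args.trials * args.layers. For reference, a minimal sketch of the event-timing pattern these counters follow (requires a CUDA device, as in the benchmark):

import torch

start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")
start.record()
y = x @ x
stop.record()
torch.cuda.synchronize()  # elapsed_time is valid only after both events complete
print("{:.3f} ms".format(start.elapsed_time(stop)))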