Skip to content

Commit f4b9bdf

Browse files
committed
Merge: [SE3Transformer/DGLPyTorch] Benchmarking fixes and tweaks
2 parents 7834973 + 2517f61 commit f4b9bdf

4 files changed

Lines changed: 22 additions & 13 deletions

File tree

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -99,10 +99,10 @@ def prepare_data(self):
     def _collate(self, samples):
         graphs, y, *bases = map(list, zip(*samples))
         batched_graph = dgl.batch(graphs)
-        edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
+        edge_feats = {'0': batched_graph.edata['edge_attr'][:, :self.EDGE_FEATURE_DIM, None]}
         batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
         # get node features
-        node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
+        node_feats = {'0': batched_graph.ndata['attr'][:, :self.NODE_FEATURE_DIM, None]}
         targets = (torch.cat(y) - self.targets_mean) / self.targets_std

         if bases:
```

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -127,7 +127,9 @@ def __init__(self,
                 fiber_edge=fiber_edge,
                 self_interaction=True,
                 use_layer_norm=use_layer_norm,
-                max_degree=self.max_degree))
+                max_degree=self.max_degree,
+                fuse_level=self.fuse_level,
+                low_memory=low_memory))
         self.graph_modules = Sequential(*graph_modules)

         if pooling is not None:
```

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/inference.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -116,18 +116,23 @@ def evaluate(model: nn.Module,
     torch.set_float32_matmul_precision('high')

     test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
-    evaluate(model,
-             test_dataloader,
-             callbacks,
-             args)
+    if not args.benchmark:
+        evaluate(model,
+                 test_dataloader,
+                 callbacks,
+                 args)

-    for callback in callbacks:
-        callback.on_validation_end()
+        for callback in callbacks:
+            callback.on_validation_end()

-    if args.benchmark:
+    else:
         world_size = dist.get_world_size() if dist.is_initialized() else 1
-        callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
-        for _ in range(6):
+        callbacks = [PerformanceCallback(
+            logger, args.batch_size * world_size,
+            warmup_epochs=1 if args.epochs > 1 else 0,
+            mode='inference'
+        )]
+        for _ in range(args.epochs):
             evaluate(model,
                      test_dataloader,
                      callbacks,
```

DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -221,7 +221,9 @@ def print_parameters_count(model):
     if args.benchmark:
         logging.info('Running benchmark mode')
         world_size = dist.get_world_size() if dist.is_initialized() else 1
-        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
+        callbacks = [PerformanceCallback(
+            logger, args.batch_size * world_size, warmup_epochs=1 if args.epochs > 1 else 0
+        )]
     else:
         callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                      QM9LRSchedulerCallback(logger, epochs=args.epochs)]
```

0 commit comments

Comments (0)