@aurooj you can do this with a new utility in TorchVision: https://pytorch.org/vision/stable/feature_extraction.html

Try installing the latest PyTorch (1.10) and TorchVision (0.11), which came out this week, then:

```python
from pprint import pprint

import timm
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

model = timm.create_model('vit_base_patch32_224')
nodes, _ = get_graph_node_names(model, tracer_kwargs={'leaf_modules': [PatchEmbed]})
pprint(nodes)
```

This will print out all the "nodes" that you may extract features from. On doing this and inspecting the list, you can pick the attention softmax of block `N`, `blocks.{N}.attn.softmax`, and extract it:

```python
import torch

N = 2
# This is the "one line of code" that does what you want
feature_extractor = create_feature_extractor(
    model, return_nodes=[f'blocks.{N}.attn.softmax'],
    tracer_kwargs={'leaf_modules': [PatchEmbed]})
with torch.no_grad():
    out = feature_extractor(torch.zeros(1, 3, 224, 224))
print(out[f'blocks.{N}.attn.softmax'].shape)
```

Note: You need the
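To sanity-check the printed shape and turn the map into per-patch scores, here is a minimal pure-Python sketch. It assumes the standard ViT-Base layout of `vit_base_patch32_224` (12 heads, one class token prepended to a 7x7 patch grid) and that the softmax output is laid out as (batch, heads, query tokens, key tokens). The helpers `attn_map_shape` and `cls_attention_grid` are hypothetical, not timm or TorchVision API, and averaging the class-token row over heads is only one common way to get a per-patch score.

```python
# Hypothetical helper: expected shape of the 'blocks.{N}.attn.softmax' output,
# assuming a ViT with a single class token prepended to the patch tokens.
def attn_map_shape(batch, img_size=224, patch_size=32, num_heads=12, num_prefix_tokens=1):
    grid = img_size // patch_size              # patches per side: 224 // 32 = 7
    tokens = grid * grid + num_prefix_tokens   # 49 patches + 1 class token = 50
    return (batch, num_heads, tokens, tokens)


def cls_attention_grid(attn_one_image, grid):
    """Average the class-token row over heads and reshape it to the patch grid.

    attn_one_image: nested lists shaped [heads][tokens][tokens], i.e. one
    image's slice of the softmax output (row = query token, column = key token).
    """
    heads = len(attn_one_image)
    # Row 0 is the class token's query; columns 1.. are the patch keys.
    cls_row = [
        sum(attn_one_image[h][0][k] for h in range(heads)) / heads
        for k in range(1, grid * grid + 1)
    ]
    # Split the flat list of patch scores into grid rows.
    return [cls_row[r * grid:(r + 1) * grid] for r in range(grid)]


print(attn_map_shape(batch=1))  # (1, 12, 50, 50)

# Toy input: 2 heads, 2x2 patch grid (5 tokens), uniform attention.
toy = [[[1 / 5] * 5 for _ in range(5)] for _ in range(2)]
print(cls_attention_grid(toy, grid=2))
```

With the real tensor, `a = out[f'blocks.{N}.attn.softmax']`, the same reduction is `a[0].mean(0)[0, 1:].reshape(7, 7)`.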
-
@alexander-soare Thanks a lot for the prompt reply. This is helpful!
-
Hi @rwightman, thanks a lot for this amazing work.
I am looking to extract attention scores for the `vit_base_patch32_224` model from any specific layer I want. So far, I was able to load a pretrained ViT model and load weights from any X layers I want. However, when I do the forward pass on these layers, I only get the hidden-state outputs.
Is there a way to also get attention scores for each patch in ViT using timm?
Looking forward to hearing from you soon.
Thanks!
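As background for the question: in a ViT block, the per-patch attention scores of one head are the row-wise softmax of the scaled query-key dot products, softmax(QK^T / sqrt(d)), where d is the head dimension; row i says how much token i attends to every token. A minimal single-head sketch in pure Python (the helper names `softmax` and `attention_scores` are hypothetical, not part of timm):

```python
import math


def softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_scores(q, k):
    """Single-head attention weights softmax(q @ k^T / sqrt(d)).

    q, k: lists of token vectors (tokens x d). Returns a tokens x tokens
    matrix whose row i holds query token i's weights over all key tokens.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    logits = [
        [sum(qi * ki for qi, ki in zip(q_row, k_row)) * scale for k_row in k]
        for q_row in q
    ]
    return [softmax(row) for row in logits]


# Two tokens with orthogonal queries/keys: each token attends most to itself.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
w = attention_scores(q, k)
print(w)
```

Each row of `w` sums to 1, which is also a quick check to run on rows of the extracted `attn.softmax` tensor.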