import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

from model.SAM import build_sam_vit_h
from model.llava.model.language_model.llava_llama import LlavaLlamaForCausalLM, LlavaLlamaModel


def calculate_dice_loss(predictions: torch.Tensor, ground_truth: torch.Tensor, mask_count: float, scale_factor=1000,
                        epsilon=1e-6):
    """
    Calculate the DICE loss, a measure similar to generalized IOU for masks.
    """
    predictions = predictions.sigmoid()
    predictions = predictions.flatten(1, 2)
    ground_truth = ground_truth.flatten(1, 2)
    intersection = 2 * (predictions / scale_factor * ground_truth).sum(dim=-1)
    union = (predictions / scale_factor).sum(dim=-1) + (ground_truth / scale_factor).sum(dim=-1)
    dice_loss = 1 - (intersection + epsilon) / (union + epsilon)
    dice_loss = dice_loss.sum() / (mask_count + 1e-8)
    return dice_loss
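# Illustrative sanity check (a sketch, not part of the original module): for a
# pair of (batch, H, W) tensors the loss lands in [0, 1], approaching 0 as the
# sigmoid of the logits matches the binary target.
#
#   logits = torch.full((1, 4, 4), 10.0)                 # sigmoid ~= 1 everywhere
#   target = torch.ones((1, 4, 4))
#   calculate_dice_loss(logits, target, mask_count=1.0)  # ~= 0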


def compute_sigmoid_cross_entropy(predictions: torch.Tensor, targets: torch.Tensor, mask_count: float):
    """
    Compute sigmoid cross-entropy loss for binary classification.
    """
    loss = F.binary_cross_entropy_with_logits(predictions, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1)
    loss = loss.sum() / (mask_count + 1e-8)
    return loss
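# Illustrative sanity check (a sketch, not part of the original module): the
# per-pixel BCE is averaged over each (H, W) mask, then summed and normalised
# by mask_count, so confident correct logits drive the loss towards 0.
#
#   logits = torch.full((1, 4, 4), 10.0)
#   target = torch.ones((1, 4, 4))
#   compute_sigmoid_cross_entropy(logits, target, mask_count=1.0)  # ~= 0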


class GLaMMBaseModel:
    def __init__(self, config, **kwargs):
        super(GLaMMBaseModel, self).__init__(config)
        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)

        # Set config attributes if they don't exist
        self.config.train_mask_decoder = getattr(
            self.config, "train_mask_decoder", kwargs.get("train_mask_decoder", False)
        )
        self.config.out_dim = getattr(self.config, "out_dim", kwargs.get("out_dim", 512))

        self.initialize_glamm_model(self.config)

    def initialize_glamm_model(self, config):
        # Initialize the visual model
        self.grounding_encoder = build_sam_vit_h(self.vision_pretrained)
        self._configure_grounding_encoder(config)
        # Initialize the text projection layer
        self._initialize_text_projection_layer()

    def _configure_grounding_encoder(self, config):
        # Freeze all visual model parameters
        for param in self.grounding_encoder.parameters():
            param.requires_grad = False
        # Train the mask decoder if specified
        if config.train_mask_decoder:
            self._train_mask_decoder()

    def _train_mask_decoder(self):
        self.grounding_encoder.mask_decoder.train()
        for param in self.grounding_encoder.mask_decoder.parameters():
            param.requires_grad = True

    def _initialize_text_projection_layer(self):
        in_dim, out_dim = self.config.hidden_size, self.config.out_dim
        text_projection_layers = [
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True), nn.Linear(in_dim, out_dim), nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_projection_layers)])
        self.text_hidden_fcs.train()
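        # Shape note (illustrative): text_hidden_fcs[0] projects last-layer
        # hidden states (N, hidden_size) into the SAM prompt space
        # (N, out_dim), e.g. (N, 4096) -> (N, 512) for a 7B LLaMA backbone
        # with the default out_dim above.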


class GLaMMModel(GLaMMBaseModel, LlavaLlamaModel):
    def __init__(self, config, **kwargs):
        super(GLaMMModel, self).__init__(config, **kwargs)
        self._configure_model_settings()

    def _configure_model_settings(self):
        self.config.use_cache = False
        self.config.vision_module = self.config.mm_vision_module
        self.config.select_feature_type = "patch"
        self.config.image_aspect = "square"
        self.config.image_grid_points = None
        self.config.tune_mlp_adapter = False
        self.config.freeze_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.use_image_patch_token = False


class GLaMMForCausalLM(LlavaLlamaForCausalLM):
    def __init__(self, config, **kwargs):
        self._set_model_configurations(config, kwargs)
        super().__init__(config)
        self.model = GLaMMModel(config, **kwargs)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def _set_model_configurations(self, config, kwargs):
        config.mm_use_image_start_end = kwargs.pop("use_mm_start_end", True)
        config.mm_vision_module = kwargs.get("vision_module", "openai/clip-vit-large-patch14-336")
        self._initialize_loss_weights(kwargs)
        config.num_reg_features = kwargs.get("num_level_reg_features", 4)
        config.with_region = kwargs.get("with_region", True)
        config.bbox_token_idx = kwargs.get("bbox_token_idx", 32002)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

    def _initialize_loss_weights(self, kwargs):
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)

    def get_grounding_encoder_embs(self, pixel_values: torch.FloatTensor):
        with torch.no_grad():
            return torch.cat([self._encode_single_image(img) for img in pixel_values], dim=0)

    def _encode_single_image(self, image):
        torch.cuda.empty_cache()
        return self.model.grounding_encoder.image_encoder(image.unsqueeze(0))

    def forward(self, **kwargs):
        return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)

    def model_forward(self, global_enc_images: torch.FloatTensor, grounding_enc_images: torch.FloatTensor,
                      bboxes: torch.FloatTensor, input_ids: torch.LongTensor, labels: torch.LongTensor,
                      attention_masks: torch.LongTensor, offset: torch.LongTensor, masks_list: List[torch.FloatTensor],
                      label_list: List[torch.Tensor], resize_list: List[tuple], inference: bool = False, **kwargs):
        # Handle inference or training paths
        if inference:
            output_hidden_states = self._inference_path(input_ids, global_enc_images, attention_masks)
        else:
            output, output_hidden_states = self._training_path(
                global_enc_images, bboxes, input_ids, labels, attention_masks, offset
            )

        if grounding_enc_images is not None:
            # Extract grounding encoder image embeddings
            image_embeddings = self.get_grounding_encoder_embs(grounding_enc_images)
            assert image_embeddings.shape[0] == len(offset) - 1
            # Create segmentation token mask
            seg_token_mask = self._create_seg_token_mask(input_ids)
            # Process hidden states
            hidden_states, pred_embeddings = self._process_hidden_states(output_hidden_states, seg_token_mask, offset)
            # Generate and post-process masks
            pred_masks = self._generate_and_postprocess_masks(
                pred_embeddings, image_embeddings, resize_list, label_list
            )
            if inference:
                return {"pred_masks": pred_masks, "gt_masks": masks_list}
        else:
            pred_masks = None

        # Calculate losses
        return self._calculate_losses(pred_masks, masks_list, output)

    def _create_seg_token_mask(self, input_ids):
        mask = input_ids[:, 1:] == self.seg_token_idx
        return torch.cat(
            [torch.zeros((mask.shape[0], 575)).bool().cuda(), mask, torch.zeros((mask.shape[0], 1)).bool().cuda()],
            dim=1
        )
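    # Offset arithmetic (illustrative, assuming the default CLIP ViT-L/14-336
    # vision tower configured above): the single image token in input_ids is
    # expanded to (336 / 14) ** 2 = 576 patch embeddings, a net growth of 575
    # positions, which is why 575 leading False entries realign the mask with
    # the hidden-state sequence; the trailing False restores the length lost
    # by the [:, 1:] slice above.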

    def _inference_path(self, input_ids, global_enc_images, attention_masks):
        length = input_ids.shape[0]
        global_enc_images_extended = global_enc_images.expand(length, -1, -1, -1).contiguous()

        # Process and return inference output
        output_hidden_states = []
        for i in range(input_ids.shape[0]):
            output_i = super().forward(
                images=global_enc_images_extended[i:i + 1], attention_mask=attention_masks[i:i + 1],
                input_ids=input_ids[i:i + 1], output_hidden_states=True,
            )
            output_hidden_states.append(output_i.hidden_states)
            torch.cuda.empty_cache()
        output_hidden_states = torch.cat(output_hidden_states, dim=0)
        output_hidden_states = [output_hidden_states]
        return output_hidden_states

    def _training_path(self, global_enc_images, bboxes, input_ids, labels, attention_masks, offset):
        global_enc_images = self._prepare_global_enc_image(global_enc_images, offset)
        bboxes_list = bboxes

        output = super().forward(
            images=global_enc_images, attention_mask=attention_masks, input_ids=input_ids, labels=labels,
            output_hidden_states=True, bboxes=bboxes_list,
        )
        output_hidden_states = output.hidden_states
        return output, output_hidden_states

    def _prepare_global_enc_image(self, global_enc_image, offset):
        global_enc_image_list = []
        for i in range(len(offset) - 1):
            start_i, end_i = offset[i], offset[i + 1]
            global_enc_image_i = global_enc_image[i].unsqueeze(0).expand(end_i - start_i, -1, -1, -1).contiguous()
            global_enc_image_list.append(global_enc_image_i)
        return torch.cat(global_enc_image_list, dim=0)

    def _process_hidden_states(self, output_hidden_states, seg_token_mask, offset, infer=False):
        hidden_states = [self.model.text_hidden_fcs[0](output_hidden_states[-1])]
        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
        pred_embeddings = last_hidden_state[seg_token_mask]

        seg_token_counts = seg_token_mask.int().sum(-1)
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)
        if not infer:
            seg_token_offset = seg_token_offset[offset]

        pred_embeddings_list = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            pred_embeddings_list.append(pred_embeddings[start_i:end_i])
        return hidden_states, pred_embeddings_list
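    # Worked example (illustrative): with per-sample [SEG] counts [2, 1],
    # cumsum gives seg_token_offset = [0, 2, 3], so pred_embeddings rows
    # 0:2 go to the first conversation and row 2:3 to the second.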

    def _generate_and_postprocess_masks(self, pred_embeddings, image_embeddings, resize_list, label_list, infer=False):
        pred_masks = []
        for i, pred_embedding in enumerate(pred_embeddings):
            sparse_embeddings, dense_embeddings = self.model.grounding_encoder.prompt_encoder(
                points=None, boxes=None, masks=None, text_embeds=pred_embedding.unsqueeze(1)
            )
            sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)
            low_res_masks, _ = self.model.grounding_encoder.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.model.grounding_encoder.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # During inference, label_list holds original image sizes in place of label tensors
            orig_size = label_list[i].shape if not infer else label_list[i]
            pred_mask = self.model.grounding_encoder.postprocess_masks(
                low_res_masks, input_size=resize_list[i], original_size=orig_size,
            )
            pred_masks.append(pred_mask[:, 0])
        return pred_masks

    def _calculate_losses(self, pred_masks, masks_list, output):
        loss_components = self._compute_loss_components(pred_masks, masks_list, output)
        return loss_components

    def _compute_loss_components(self, pred_masks, masks_list, output):
        # Initialize loss components
        ce_loss = output.loss * self.ce_loss_weight
        mask_bce_loss = torch.tensor(0.0, device=ce_loss.device)
        mask_dice_loss = torch.tensor(0.0, device=ce_loss.device)
        num_masks = 0

        if pred_masks:
            # Iterate over the batch and accumulate mask-related losses
            for batch_idx, pred_mask in enumerate(pred_masks):
                if pred_mask.numel() > 0:  # Skip empty predictions
                    gt_mask = masks_list[batch_idx]
                    # Truncate gt_mask so the mask counts match if needed
                    if gt_mask.shape[0] != pred_mask.shape[0]:
                        gt_mask = gt_mask[:pred_mask.shape[0]]
                    assert gt_mask.shape[0] == pred_mask.shape[0], \
                        f"Shape mismatch: gt_mask {gt_mask.shape}, pred_mask {pred_mask.shape}"

                    # Compute Binary Cross-Entropy loss
                    mask_bce_loss += (compute_sigmoid_cross_entropy(pred_mask, gt_mask, mask_count=gt_mask.shape[0]) *
                                      gt_mask.shape[0])
                    # Compute Dice loss
                    mask_dice_loss += (calculate_dice_loss(pred_mask, gt_mask, mask_count=gt_mask.shape[0]) *
                                       gt_mask.shape[0])
                    num_masks += gt_mask.shape[0]

        # Normalize the losses
        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        # Aggregate all loss components
        total_loss = ce_loss + mask_loss
        return {"loss": total_loss, "ce_loss": ce_loss, "mask_bce_loss": mask_bce_loss,
                "mask_dice_loss": mask_dice_loss, "mask_loss": mask_loss}

    def evaluate(self, global_enc_images, grounding_enc_images, input_ids, resize_list, orig_sizes, max_tokens_new=32,
                 bboxes=None):
        with torch.no_grad():
            generation_outputs = self.generate(
                images=global_enc_images, input_ids=input_ids, bboxes=bboxes, max_new_tokens=max_tokens_new,
                num_beams=1, output_hidden_states=True, return_dict_in_generate=True,
            )

            output_hidden_states = generation_outputs.hidden_states
            generated_output_ids = generation_outputs.sequences

            seg_token_mask = generated_output_ids[:, 1:] == self.seg_token_idx
            # Adjust for IMAGE_TOKEN_INDEX (assumes a single image at the start)
            seg_token_mask = torch.cat(
                [torch.zeros((seg_token_mask.shape[0], 575), dtype=torch.bool).cuda(), seg_token_mask], dim=1,
            )

            # Process hidden states
            hidden_states, predicted_embeddings = self._process_hidden_states(
                output_hidden_states, seg_token_mask, None, infer=True
            )
            image_embeddings = self.get_grounding_encoder_embs(grounding_enc_images)

            # Generate and post-process masks
            pred_masks = self._generate_and_postprocess_masks(
                predicted_embeddings, image_embeddings, resize_list, orig_sizes, infer=True
            )
        return generated_output_ids, pred_masks
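

# Illustrative usage (a sketch; the paths, token index, and loss weights below
# are assumptions, not defined in this file):
#
#   model = GLaMMForCausalLM.from_pretrained(
#       "path/to/glamm-checkpoint", seg_token_idx=32003,
#       vision_pretrained="path/to/sam_vit_h.pth", train_mask_decoder=True,
#       ce_loss_weight=1.0, bce_loss_weight=2.0, dice_loss_weight=0.5,
#   )
#   output_ids, pred_masks = model.evaluate(
#       global_enc_images, grounding_enc_images, input_ids,
#       resize_list, orig_sizes, max_tokens_new=32,
#   )


if __name__ == "__main__":
    # Minimal CPU smoke test of the two loss helpers (illustrative only; the
    # full model additionally needs SAM and LLaVA checkpoints).
    demo_logits = torch.randn(2, 32, 32)
    demo_targets = (torch.rand(2, 32, 32) > 0.5).float()
    print("dice:", calculate_dice_loss(demo_logits, demo_targets, mask_count=2.0).item())
    print("bce:", compute_sigmoid_cross_entropy(demo_logits, demo_targets, mask_count=2.0).item())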