---
title: "Florence-2 - Vision Foundation Model - Examples"
date: 2024-06-25T00:08:25+01:00
description: Examples and usage of Florence-2 Vision Model
menu:
  sidebar:
    name: Florence-2 LVM
    identifier: florence
    parent: computer_vision
    weight: 9
hero: images/florence-2-lvm-computer-vision-exploration_28_3.png
tags: ["Deep Learning", "Computer Vision", "Machine Learning"]
categories: ["Computer Vision"]
---
## Install dependencies

Run the following command to install the extra dependencies that may be needed (especially if inference is performed on the CPU):

```python
%pip install einops flash_attn
```

On Kaggle, `transformers` and `torch` are already installed; otherwise, you need to install them on your local machine as well.

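For a local setup, they can be installed the same way — a minimal sketch (pin versions as needed for your environment):

```python
%pip install transformers torch
```
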
## Import Libraries

```python
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import requests
import copy
import torch
%matplotlib inline
```

## Import the model

We can choose *Florence-2-large* or *Florence-2-large-ft* (fine-tuned).

```python
model_id = 'microsoft/Florence-2-large-ft'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()
model = model.to(device)  # move the model to the GPU if one is available
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```

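On a GPU, the model can also be loaded in half precision to roughly halve its memory footprint — a minimal sketch, assuming a CUDA device is available (the processor inputs must then be cast to the same dtype):

```python
# Optional: load in float16 on GPU (assumes CUDA is available)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, trust_remote_code=True
).eval().to("cuda")

# The image tensors must match the model dtype at inference time, e.g.:
# inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
```
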
## Define inference function

```python
def run_inference(task_prompt, text_input=None):
    # Builds the prompt, runs beam-search generation, and parses the output.
    # Note: uses the global `image` defined below.
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Convert the raw decoded string into a task-specific structure
    # (caption text, bounding boxes, polygons, ...) scaled to the image size.
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer
```

## Get image link

```python
image_url = "http://lerablog.org/wp-content/uploads/2013/06/two-cars.jpg"  # any image URL or local file path can be used here
image = Image.open(requests.get(image_url, stream=True).raw)
image
```

## Run pre-defined tasks without additional inputs

### Caption

```python
task_prompt = '<CAPTION>'
run_inference(task_prompt)
```

> {'<CAPTION>': 'Two sports cars parked next to each other on a road.'}

```python
task_prompt = '<DETAILED_CAPTION>'
run_inference(task_prompt)
```

> {'<DETAILED_CAPTION>': 'In this image we can see two cars on the road. In the background, we can also see water, hills and the sky.'}

```python
task_prompt = '<MORE_DETAILED_CAPTION>'
run_inference(task_prompt)
```

> {'<MORE_DETAILED_CAPTION>': 'There are two cars parked on the street. There is water behind the cars. There are mountains behind the water. The cars are yellow and black.'}

## Object Detection

```python
task_prompt = '<OD>'
results = run_inference(task_prompt)
print(results)
```

> {'<OD>': {'bboxes': [[336.1050109863281, 115.95000457763672, 599.4450073242188, 248.5500030517578], [18.584999084472656, 117.45000457763672, 304.6050109863281, 236.25001525878906], [113.08499908447266, 177.15000915527344, 172.30499267578125, 235.95001220703125], [404.1449890136719, 187.95001220703125, 454.54498291015625, 248.25001525878906], [336.1050109863281, 176.25, 380.2049865722656, 235.95001220703125], [26.774999618530273, 173.85000610351562, 73.3949966430664, 228.15000915527344], [244.125, 216.15000915527344, 291.375, 231.15000915527344], [546.5250244140625, 236.5500030517578, 588.7349853515625, 245.85000610351562], [481.635009765625, 148.35000610351562, 509.3550109863281, 157.65000915527344]], 'labels': ['car', 'car', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel']}}

```python
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_bbox(image, data):
    # Create a figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(image)

    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Create a Rectangle patch
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        # Add the rectangle to the Axes
        ax.add_patch(rect)
        # Annotate the label
        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Remove the axis ticks and labels
    ax.axis('off')

    # Show the plot
    plt.show()
```

---

```python
plot_bbox(image, results['<OD>'])
```

## Dense Region Caption

```python
task_prompt = '<DENSE_REGION_CAPTION>'
results = run_inference(task_prompt)
dense_region_res = results
print(results)
```

> {'<DENSE_REGION_CAPTION>': {'bboxes': [[334.8450012207031, 115.95000457763672, 599.4450073242188, 248.5500030517578], [18.584999084472656, 117.45000457763672, 304.6050109863281, 236.25001525878906], [113.71499633789062, 177.15000915527344, 172.30499267578125, 235.95001220703125], [404.1449890136719, 187.95001220703125, 453.9150085449219, 248.25001525878906], [26.774999618530273, 173.85000610351562, 73.3949966430664, 228.15000915527344], [336.1050109863281, 176.25, 380.2049865722656, 235.95001220703125], [244.125, 216.45001220703125, 290.7449951171875, 230.85000610351562], [546.5250244140625, 236.5500030517578, 588.7349853515625, 245.85000610351562], [481.635009765625, 148.35000610351562, 509.3550109863281, 157.65000915527344]], 'labels': ['yellow sports car', 'sports car', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel', 'wheel']}}

```python
plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
```

## Phrase Grounding

```python
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
results = run_inference(task_prompt, text_input="Yellow car with islands in background")
print(results)
plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
```

> {'<CAPTION_TO_PHRASE_GROUNDING>': {'bboxes': [[335.4750061035156, 115.6500015258789, 601.9649658203125, 250.35000610351562], [0.3149999976158142, 12.15000057220459, 629.0549926757812, 103.6500015258789]], 'labels': ['Yellow car', 'islands']}}

## Segmentation

```python
task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
results = run_inference(task_prompt, text_input="yellow car and island")
print(results)
```

> {'<REFERRING_EXPRESSION_SEGMENTATION>': {'polygons': [[[348.07501220703125, 149.85000610351562, 364.4549865722656, 147.75, 387.135009765625, 126.75000762939453, 414.8550109863281, 118.35000610351562, 473.44500732421875, 116.25000762939453, 508.7250061035156, 120.45000457763672, 538.3350219726562, 147.15000915527344, 545.2650146484375, 145.0500030517578, 557.2349853515625, 149.25, 557.864990234375, 156.15000915527344, 547.7849731445312, 159.75, 572.9849853515625, 169.65000915527344, 588.7349853515625, 178.0500030517578, 596.9249877929688, 202.65000915527344, 599.4450073242188, 223.65000915527344, 596.9249877929688, 236.25001525878906, 588.7349853515625, 237.75001525878906, 579.2849731445312, 246.15000915527344, 553.4550170898438, 246.15000915527344, 547.1549682617188, 239.85000610351562, 450.135009765625, 239.85000610351562, 438.79498291015625, 248.25001525878906, 419.2649841308594, 248.25001525878906, 407.92498779296875, 237.75001525878906, 406.6650085449219, 229.35000610351562, 378.94500732421875, 225.15000915527344, 376.42498779296875, 233.5500030517578, 348.07501220703125, 235.65000915527344, 339.2550048828125, 225.15000915527344, 336.1050109863281, 198.45001220703125, 336.7349853515625, 175.95001220703125, 343.6650085449219, 159.75]]], 'labels': ['']}}

```python
from PIL import Image, ImageDraw, ImageFont
import random
import numpy as np

colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
            'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']

def draw_polygons(image, prediction, fill_mask=False):
    """
    Draws segmentation masks with polygons on an image.

    Parameters:
    - image: PIL Image to draw on (modified in place).
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
                  'polygons' is a list of lists, each containing vertices of a polygon.
                  'labels' is a list of labels corresponding to each polygon.
    - fill_mask: Boolean indicating whether to fill the polygons with color.
    """
    draw = ImageDraw.Draw(image)

    # Set up scale factor if needed (use 1 if not scaling)
    scale = 1

    # Iterate over polygons and labels
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        color = random.choice(colormap)
        fill_color = random.choice(colormap) if fill_mask else None

        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print('Invalid polygon:', _polygon)
                continue

            _polygon = (_polygon * scale).reshape(-1).tolist()

            # Draw the polygon
            if fill_mask:
                draw.polygon(_polygon, outline=color, fill=fill_color)
            else:
                draw.polygon(_polygon, outline=color)

            # Draw the label text near the first vertex
            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)

    # Display the image (display() is a notebook built-in)
    display(image)
```

```python
output_image = copy.deepcopy(image)
draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
```

### Regions Segmentation

```python
def bbox_to_loc(bbox):
    # Box coordinates (x1, y1, x2, y2) must be rescaled to the 0-999 range
    # expected by the <loc_...> tokens. Uses the global `width`/`height`
    # defined in the next cell.
    return f"<loc_{int(bbox[0]*999/width)}><loc_{int(bbox[1]*999/height)}><loc_{int(bbox[2]*999/width)}><loc_{int(bbox[3]*999/height)}>"

with torch.no_grad():
    torch.cuda.empty_cache()  # free cached GPU memory before the next round of inferences
```

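As a quick illustration of the rescaling (with made-up numbers): for a hypothetical 1000×500 image, the box `[100, 50, 300, 200]` maps to the tokens below.

```python
width, height = 1000, 500  # hypothetical image size, for illustration only
print(bbox_to_loc([100, 50, 300, 200]))
# -> <loc_99><loc_99><loc_299><loc_399>
# (the next cell resets width/height from the actual image)
```
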
```python
output_image = copy.deepcopy(image)
height, width = image.height, image.width
task_prompt = '<REGION_TO_SEGMENTATION>'

# Run segmentation on every region found by the dense region captioning step
for bbox in dense_region_res['<DENSE_REGION_CAPTION>']['bboxes']:
    print(bbox_to_loc(bbox))
    results = run_inference(task_prompt, text_input=bbox_to_loc(bbox))
    draw_polygons(output_image, results[task_prompt], fill_mask=True)

plot_bbox(output_image, dense_region_res['<DENSE_REGION_CAPTION>'])
```

> <loc_530><loc_386><loc_950><loc_827>
> <loc_29><loc_391><loc_483><loc_786>
> <loc_180><loc_589><loc_273><loc_785>
> <loc_640><loc_625><loc_719><loc_826>
> <loc_42><loc_578><loc_116><loc_759>
> <loc_532><loc_586><loc_602><loc_785>
> <loc_387><loc_720><loc_461><loc_768>
> <loc_866><loc_787><loc_933><loc_818>
> <loc_763><loc_494><loc_807><loc_524>

## OCR

```python
url = "https://m.media-amazon.com/images/I/510sf0pRTlL.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
image
```

```python
task_prompt = '<OCR_WITH_REGION>'
results = run_inference(task_prompt)
print(results)
```

> {'<OCR_WITH_REGION>': {'quad_boxes': [[143.8125, 146.25, 280.9624938964844, 146.25, 280.9624938964844, 172.25, 143.8125, 172.25], [134.0625, 176.25, 281.9375, 176.25, 281.9375, 202.25, 134.0625, 202.25], [172.73748779296875, 206.25, 284.2124938964844, 206.25, 284.2124938964844, 216.25, 172.73748779296875, 216.25], [150.3125, 238.25, 281.9375, 238.25, 281.9375, 247.25, 150.3125, 247.25], [139.58749389648438, 254.25, 284.2124938964844, 254.25, 284.2124938964844, 277.75, 139.58749389648438, 277.75], [133.08749389648438, 283.75, 285.1875, 283.75, 285.1875, 307.75, 133.08749389648438, 307.75], [140.5625, 312.75, 281.9375, 312.75, 281.9375, 320.75, 140.5625, 320.75]], 'labels': ['</s>QUANTUM', 'MECHANICS', '(Non-relativistic Theory)', 'Course of Theoretical Phyias Volume 3', 'L.D. LANDAU', 'E.M. LIFSHITZ', 'Initiute of Physical Problems, USSR Academy of']}}

The extracted text is overall very close to the original (the small misspellings, e.g. "Phyias" and "Initiute", come from the model itself).

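If only the raw text is needed, without region boxes, Florence-2 also exposes a plain `<OCR>` task — a minimal sketch using the same inference helper:

```python
# Plain OCR: returns the recognized text without quad boxes
task_prompt = '<OCR>'
print(run_inference(task_prompt))
```

To visualize the detected regions instead, the helper below draws each quadrilateral with its label.
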
```python
def draw_ocr_bboxes(image, prediction):
    scale = 1
    draw = ImageDraw.Draw(image)
    bboxes, labels = prediction['quad_boxes'], prediction['labels']
    for box, label in zip(bboxes, labels):
        color = random.choice(colormap)
        # Each quad box is given as 8 coordinates (x1, y1, ..., x4, y4)
        new_box = (np.array(box) * scale).tolist()
        draw.polygon(new_box, width=3, outline=color)
        draw.text((new_box[0] - 8, new_box[1] - 10),
                  "{}".format(label),
                  align="right",
                  fill=color)
    display(image)

output_image = copy.deepcopy(image)
draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
```