1
+ import glob
1
2
import os
2
- import numpy as np
3
+
3
4
from PIL import Image
4
- from capgen import CaptionGenerator
5
- from flask import Flask , request , render_template , request , redirect
6
5
from elasticsearch import Elasticsearch
7
6
from elasticsearch .helpers import bulk
7
+ from flask import Flask , render_template , request , Response
8
8
from werkzeug .utils import secure_filename
9
- import glob
9
+ import json
10
+
11
+ from capgen import CaptionGenerator
10
12
11
13
os .environ ['CUDA_VISIBLE_DEVICES' ] = ''
12
14
es = Elasticsearch ()
13
15
gencap = CaptionGenerator ()
14
16
17
+
15
18
def description_search (query ):
16
19
global es
17
20
results = es .search (
18
21
index = "desearch" ,
19
22
body = {
20
23
"size" : 20 ,
21
24
"query" : {
22
- "match" : {"description" : query }
25
+ "match" : {"description" : query }
23
26
}
24
- })
25
- hitCount = results ['hits' ]['total' ]
27
+ })
28
+ hitCount = results ['hits' ]['total' ][ 'value' ]
26
29
print (results )
27
30
28
31
if hitCount > 0 :
29
32
if hitCount is 1 :
30
- print (str (hitCount ),' result' )
33
+ print (str (hitCount ), ' result' )
31
34
else :
32
35
print (str (hitCount ), 'results' )
33
- answers = []
36
+ answers = []
34
37
max_score = results ['hits' ]['max_score' ]
35
38
36
39
if max_score >= 0.35 :
37
40
for hit in results ['hits' ]['hits' ]:
38
41
if hit ['_score' ] > 0.5 * max_score :
39
42
desc = hit ['_source' ]['description' ]
40
43
imgurl = hit ['_source' ]['imgurl' ]
41
- answers .append ([imgurl ,desc ])
44
+ answers .append ([imgurl , desc ])
42
45
else :
43
46
answers = []
44
47
return answers
45
48
49
+
46
50
app = Flask (__name__ )
47
- app .config ['UPLOAD_FOLDER' ] = os .path .join ('static' ,'database' )
48
- app .config ['TEMP_UPLOAD_FOLDER' ] = os .path .join ('static' ,'uploads' )
49
- app .config ['ALLOWED_EXTENSIONS' ] = set (['jpg' ,'jpeg' ,'png' ])
51
+ app .config ['UPLOAD_FOLDER' ] = os .path .join ('static' , 'database' )
52
+ app .config ['TEMP_UPLOAD_FOLDER' ] = os .path .join ('static' , 'uploads' )
53
+ app .config ['ALLOWED_EXTENSIONS' ] = set (['jpg' , 'jpeg' , 'png' ])
50
54
51
55
52
56
def allowed_file (filename ):
@@ -58,11 +62,13 @@ def allowed_file(filename):
58
62
def index ():
59
63
return render_template ('home.html' )
60
64
65
+
61
66
@app .route ('/search' , methods = ['GET' , 'POST' ])
62
67
def search ():
63
68
global gencap
64
69
if request .method == 'POST' :
65
- if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (request .files ['query_img' ].filename ):
70
+ if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (
71
+ request .files ['query_img' ].filename ):
66
72
return render_template ('search.html' )
67
73
file = request .files ['query_img' ]
68
74
img = Image .open (file .stream ) # PIL image
@@ -77,12 +83,32 @@ def search():
77
83
else :
78
84
return render_template ('search.html' )
79
85
86
+
87
+ @app .route ('/api/search' , methods = ['POST' ])
88
+ def api_search ():
89
+ global gencap
90
+ if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (
91
+ request .files ['query_img' ].filename ):
92
+ return Response (response = json .dumps ({'success' : False , 'message' : 'Uploaded image is invalid or not allowed' }),
93
+ status = 400 , mimetype = "application/json" )
94
+ file = request .files ['query_img' ]
95
+ img = Image .open (file .stream ) # PIL image
96
+ uploaded_img_path = os .path .join (app .config ['TEMP_UPLOAD_FOLDER' ], file .filename )
97
+ img .save (uploaded_img_path )
98
+ query = gencap .get_caption (uploaded_img_path )
99
+ answers = description_search (query )
100
+
101
+ return Response (response = json .dumps ({'success' : True , 'answers' : answers }),
102
+ status = 200 , mimetype = "application/json" )
103
+
104
+
80
105
@app .route ('/database' )
81
106
def database ():
82
- images = glob .glob (os .path .join (app .config ['UPLOAD_FOLDER' ],'*' ))
83
- return render_template ('database.html' , database_images = images )
107
+ images = glob .glob (os .path .join (app .config ['UPLOAD_FOLDER' ], '*' ))
108
+ return render_template ('database.html' , database_images = images )
109
+
84
110
85
- @app .route ('/upload' , methods = ['GET' ,'POST' ])
111
+ @app .route ('/upload' , methods = ['GET' , 'POST' ])
86
112
def upload ():
87
113
if request .method == 'POST' :
88
114
if 'photos' not in request .files :
@@ -94,24 +120,42 @@ def upload():
94
120
file_path = os .path .join (app .config ['UPLOAD_FOLDER' ], filename )
95
121
file .save (file_path )
96
122
cap = gencap .get_caption (file_path )
97
- doc = {'imgurl' : file_path , 'description' :cap }
123
+ doc = {'imgurl' : file_path , 'description' : cap }
98
124
actions .append (doc )
99
- bulk (es ,actions ,index = "desearch" ,doc_type = "json" )
125
+ bulk (es , actions , index = "desearch" , doc_type = "json" )
100
126
return render_template ('database.html' )
101
127
102
- @app .route ('/caption' , methods = ['GET' ,'POST' ])
128
+
129
+ @app .route ('/caption' , methods = ['GET' , 'POST' ])
103
130
def caption ():
104
131
if request .method == 'POST' :
105
- if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (request .files ['query_img' ].filename ):
132
+ if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (
133
+ request .files ['query_img' ].filename ):
106
134
return render_template ('caption.html' )
107
135
file = request .files ['query_img' ]
108
136
img = Image .open (file .stream ) # PIL image
109
137
uploaded_img_path = os .path .join (app .config ['TEMP_UPLOAD_FOLDER' ], file .filename )
110
138
img .save (uploaded_img_path )
111
139
cap = gencap .get_caption (uploaded_img_path )
112
- return render_template ('caption.html' , caption = cap , query_path = uploaded_img_path )
140
+ return render_template ('caption.html' , caption = cap , query_path = uploaded_img_path )
113
141
else :
114
142
return render_template ('caption.html' )
115
143
116
- if __name__ == "__main__" :
144
+
145
+ @app .route ('/api/caption' , methods = ['POST' ])
146
+ def caption_api ():
147
+ if 'query_img' not in request .files or request .files ['query_img' ].filename == '' or not allowed_file (
148
+ request .files ['query_img' ].filename ):
149
+ return Response (response = json .dumps ({'success' : False , 'message' : 'Uploaded image is invalid or not allowed' }),
150
+ status = 400 , mimetype = "application/json" )
151
+ file = request .files ['query_img' ]
152
+ img = Image .open (file .stream ) # PIL image
153
+ uploaded_img_path = os .path .join (app .config ['TEMP_UPLOAD_FOLDER' ], file .filename )
154
+ img .save (uploaded_img_path )
155
+ cap = gencap .get_caption (uploaded_img_path )
156
+ return Response (response = json .dumps ({'success' : True , 'caption' : cap }),
157
+ status = 200 , mimetype = "application/json" )
158
+
159
+
160
+ if __name__ == "__main__" :
117
161
app .run ("127.0.0.1" , debug = True )
0 commit comments