-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
31 lines (24 loc) · 1.07 KB
/
server.py
File metadata and controls
31 lines (24 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import json
from flask import Flask, request, jsonify
from sentence_transformers.cross_encoder import CrossEncoder
# Flask application object; routes are registered on it with @app.route.
app = Flask(__name__)

# Hugging Face identifier of the cross-encoder used to score a query
# against each course's text.
_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
# Loaded once at import time; every request reuses this instance.
model = CrossEncoder(_MODEL_NAME)
# Scoring policy for /suggest (scores come from CrossEncoder.rank).
# Minimum score for a course to count as a relevant match.
_SCORE_THRESHOLD = 0.4
# The above-threshold list is only used when at least this many courses
# score above the threshold; otherwise the plain top results are returned.
_MIN_STRONG_MATCHES = 10
# Upper bound on the number of above-threshold courses returned.
_MAX_RESULTS = 12
# Number of top-ranked courses returned when strong matches are too few.
_FALLBACK_RESULTS = 3


def _load_courses(path='courses.json'):
    """Load the course catalogue and index it by its searchable text.

    Each record's ``tag``, ``title``, ``keywords`` and ``description`` fields
    are joined with ". " to form the text the model scores. The mapping goes
    from that text back to the full record, so records with identical text
    collapse to the last one read.

    Args:
        path: JSON file holding a list of course objects. A relative path is
            resolved against the process's current working directory.

    Returns:
        dict mapping searchable text -> course record.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return {
            '. '.join((c['tag'], c['title'], c['keywords'], c['description'])): c
            for c in json.load(f)
        }


def _select_titles(ranks, corpus, courses):
    """Turn cross-encoder rankings into the course titles to return.

    ``ranks`` is expected to be sorted by descending ``score`` (as
    ``CrossEncoder.rank`` returns it), each entry holding a ``corpus_id``
    index into ``corpus``.

    If at least ``_MIN_STRONG_MATCHES`` courses score above
    ``_SCORE_THRESHOLD``, return every course among the first ``_MAX_RESULTS``
    whose score is at least the threshold. Otherwise return the top
    ``_FALLBACK_RESULTS`` courses regardless of score.
    """
    def title_of(rank):
        return courses[corpus[rank['corpus_id']]]['title']

    # With fewer than _MIN_STRONG_MATCHES ranked courses there is no entry to
    # inspect at that position, so use the fallback instead of indexing past
    # the end of the list.
    has_strong_matches = (
        len(ranks) >= _MIN_STRONG_MATCHES
        and ranks[_MIN_STRONG_MATCHES - 1]['score'] > _SCORE_THRESHOLD
    )
    if not has_strong_matches:
        return [title_of(r) for r in ranks[:_FALLBACK_RESULTS]]

    titles = []
    for rank in ranks[:_MAX_RESULTS]:
        if rank['score'] < _SCORE_THRESHOLD:
            break  # sorted descending, so nothing later can qualify
        titles.append(title_of(rank))
    return titles


@app.route('/suggest', methods=['GET'])
def suggest():
    """Suggest courses relevant to the ``query`` URL parameter.

    The catalogue is re-read from ``courses.json`` on every request, so
    changes to that file are picked up without restarting the server.

    Returns:
        200 with ``{"interest_courses": [title, ...]}``, or 400 with
        ``{"error": ...}`` when ``query`` is missing or blank.
    """
    query = request.args.get('query')
    # A missing query would otherwise be passed to model.rank as None.
    if query is None or not query.strip():
        return jsonify({"error": "missing or empty 'query' parameter"}), 400

    courses = _load_courses()
    corpus = list(courses)
    if not corpus:
        # Nothing to rank; avoid calling the model with an empty corpus.
        return jsonify({"interest_courses": []})

    ranks = model.rank(query, corpus)
    return jsonify({"interest_courses": _select_titles(ranks, corpus, courses)})
if __name__ == '__main__':
    # app.run starts Flask's development server on all interfaces. The
    # interactive debugger is deliberately not forced on: it allows executing
    # arbitrary Python code from the browser (PIN-protected, but still a major
    # risk when bound to 0.0.0.0). Opt in with FLASK_DEBUG=1, which Flask.run
    # honours when ``debug`` is not passed explicitly.
    app.run(host='0.0.0.0', port=5000)