-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path app.py
82 lines (62 loc) · 2.51 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
########################################################
## This file contains the code for the Flask web app ##
########################################################
import os

import torch
from flask import Flask, render_template, request

from utils.data_utils import return_kmer, is_dna_sequence
from utils.model_utils import load_model
app = Flask(__name__, template_folder="templates")

# Model configuration. The model directory can be overridden with the
# MODEL_PATH environment variable so the app is not tied to one machine;
# the fallback is the original development path.
model_config = {
    "model_path": os.environ.get(
        "MODEL_PATH",
        r"C:\Users\Lenovo\Desktop\gittest\Virus-DNA-classification-BERT\results\classification\model",
    ),  # Path to the trained model
    "num_classes": 6,
}

# Loaded once at import time and shared by every request handled by this process.
# NOTE(review): the (model, tokenizer, device) unpacking is taken from this call
# site; confirm against utils.model_utils.load_model.
model, tokenizer, device = load_model(model_config, return_model=True)

# Dictionary to convert the 1-indexed predicted class to its display name.
class_names_dic = {
    1: "SARS-COV-1",
    2: "MERS",
    3: "SARS-COV-2",
    4: "Ebola",
    5: "Dengue",
    6: "Influenza",
}

# k-mer size passed to return_kmer before tokenization.
KMER = 3
# NOTE(review): not referenced anywhere in this file; presumably meant as a
# maximum input sequence length - confirm before relying on it.
SEQ_MAX_LEN = 30000
def huggingface_predict(input):
    """
    Classify a DNA sequence with the loaded model.

    Parameters
    ----------
    input : str
        The raw sequence submitted by the user. Surrounding whitespace
        (e.g. a trailing newline from a textarea) is ignored.

    Returns
    -------
    tuple
        ``(predicted_class, prediction_probability)``: the class name from
        ``class_names_dic`` and the softmax probability of that class as a
        percentage rounded to one decimal place. If the input is empty or
        rejected by ``is_dna_sequence``, an error message and ``0`` are
        returned instead.
    """
    sequence = input.strip()

    # Reject empty input and anything that is not a DNA sequence.
    if not sequence or not is_dna_sequence(sequence):
        return "Invalid Input. Please enter your sequence in upper case", 0

    kmer_seq = return_kmer(sequence, K=KMER)

    # Tokenize the k-mer sequence and move the tensors to the model's device.
    inputs = tokenizer(
        kmer_seq, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Inference only: disable autograd so no computation graph is kept around.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single sequence is tokenized, so the first row holds its class scores.
    probs = torch.softmax(outputs.logits, dim=1)[0]
    class_index = int(torch.argmax(probs).item())

    # Scale before rounding: round(p, 3) * 100 produced values such as
    # 56.99999999999999 because of binary floating-point representation.
    prediction_probability = round(probs[class_index].item() * 100, 1)

    # Model classes are 0-indexed; class_names_dic is keyed from 1.
    predicted_class = class_names_dic[class_index + 1]
    return predicted_class, prediction_probability
@app.route('/')
def home():
    """Render the landing page, index.html, which hosts the prediction form."""
    template = 'index.html'
    return render_template(template)
@app.route('/predict', methods=['POST'])  # handle the post request from the form in index.html
def predict():
    """Classify the submitted sequence and re-render index.html with the result.

    Reads the ``input_sequence`` form field; if the field is absent, Flask's
    ``request.form[...]`` lookup aborts the request with 400 Bad Request.
    """
    # Named `sequence` rather than `input` to avoid shadowing the builtin.
    sequence = request.form['input_sequence']
    prediction, probability = huggingface_predict(sequence)
    return render_template('index.html', prediction=prediction, probability=probability)
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary Python, so it must not be on by default. Opt in for
    # local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug)