Question Classifiers
Taxonomy, datasets and baseline models for question type classification.
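English expected-answer-type (EAT) classifiers for questions, trained on the TigreGotico/EAT dataset. The dataset has 30,017 questions and 53 fine-grained labels, grouped into 7 main categories: the 6 TREC coarse classes plus BOOL.
Inference runs in two stages. The eat7 prediction gates eat53, which restricts the fine-grained answer to sub-types of the predicted main category. This setup achieves 93.4% macro F1 on the test set.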
Used by little_questions.
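The following minimal NumPy sketch makes the gating step concrete on made-up probabilities. The `gate` helper and the four-label example are hypothetical and are not part of the released models. The sketch assumes fine labels are written `MAIN:sub`, as in the inference example. It also assumes a fallback: if eat53 puts none of its mass on the predicted main category, the raw eat53 distribution is returned unchanged.

```python
import numpy as np

def gate(probs53, labels53, main):
    """Zero out fine labels outside `main`, then renormalise what is left."""
    # Assumed label format "MAIN:sub", e.g. "HUM:ind".
    keep = np.array([lab.split(":", 1)[0] == main for lab in labels53])
    gated = np.where(keep, probs53, 0.0)
    total = gated.sum()
    # Assumption of this sketch: with no mass on `main`, keep the raw eat53 output.
    return gated / total if total > 0 else np.asarray(probs53, dtype=float)

# Toy 4-label example (made-up numbers, not model output).
labels = ["HUM:ind", "HUM:gr", "LOC:city", "LOC:country"]
probs = np.array([0.30, 0.10, 0.45, 0.15])
gated = gate(probs, labels, "HUM")  # eat7 is assumed to have predicted HUM
print(labels[int(np.argmax(gated))], gated)
```

On its own, eat53 would answer LOC:city (0.45). After gating on HUM, the distribution becomes [0.75, 0.25, 0, 0] and the answer is HUM:ind.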
7 main categories, 53 sub-types:
| Main | Sub-types |
|---|---|
| ABBR | abb, exp |
| BOOL | yesno |
| DESC | def, desc, manner, reason |
| ENTY | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| HUM | desc, gr, ind, title |
| LOC | city, country, landmass, mount, other, state, water |
| NUM | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |
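The inference example splits labels on `:` and returns strings such as `HUM:ind`, which suggests fine labels take the form `MAIN:sub`. Assuming that format, the short sketch below encodes the table above as a dictionary and confirms the 7 / 53 counts. The `TAXONOMY`, `LABELS` and `main_category` names are illustrative only. The class order stored in each model's metadata, not this list, is what indexes the model outputs.

```python
# Taxonomy from the table: main category -> fine-grained sub-types.
TAXONOMY = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed", "event",
             "food", "instru", "lang", "letter", "other", "plant", "product",
             "religion", "sport", "substance", "symbol", "techmeth", "termeq",
             "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}

# Assumed label format "MAIN:sub"; the prefix keeps e.g. DESC:desc and HUM:desc distinct.
LABELS = [f"{main}:{sub}" for main, subs in TAXONOMY.items() for sub in subs]

def main_category(label: str) -> str:
    """Return the main category of a fine-grained label ("HUM:ind" -> "HUM")."""
    return label.split(":", 1)[0]

print(len(TAXONOMY), len(LABELS))  # 7 53
```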
| File | Input variant | Output[1] |
|---|---|---|
| eat53_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat53_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |
| eat7_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat7_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |
The calibrated SVM models (`svm_cal`) come in both punctuated and unpunctuated variants.
Use the unpunctuated (`_unpunct`) models for ASR / voice-assistant input, which usually arrives without punctuation.
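The file names follow a regular pattern: `eat{7|53}_{model}[_unpunct]_EN_{version}.onnx`. The `model_file` helper below is hypothetical and not part of any released package. It builds a name from that pattern and, following the table, allows the `_unpunct` suffix only for the calibrated SVM.

```python
def model_file(level: int, model: str = "svm_cal", voice: bool = False,
               version: str = "0.9.0") -> str:
    """Build a model file name from the naming pattern in the table (hypothetical helper)."""
    if level not in (7, 53):
        raise ValueError("level must be 7 or 53")
    if voice and model != "svm_cal":
        # Per the table, only the calibrated SVM ships an unpunctuated variant.
        raise ValueError("unpunctuated variants exist only for svm_cal")
    suffix = "_unpunct" if voice else ""
    return f"eat{level}_{model}{suffix}_EN_{version}.onnx"

print(model_file(7, voice=True))  # eat7_svm_cal_unpunct_EN_0.9.0.onnx
```

The returned name can be passed straight to `onnxruntime.InferenceSession`.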
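```python
import json

import numpy as np
import onnxruntime as rt

# eat7 predicts the main category; eat53 predicts the fine-grained "MAIN:sub" label.
sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")

# Class names are stored as a JSON list in each model's custom metadata.
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])
main_of_53 = [c.split(":")[0] for c in classes53]  # "HUM:ind" -> "HUM"

def classify(text):
    inp = np.array([text], dtype=object)
    # Stage 1: index of the predicted main category (first output).
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Stage 2: per-class probabilities for the 53 fine labels (second output).
    _, probs = sess53.run(None, {"input": inp})
    row = np.asarray(probs[0], dtype=np.float64).copy()
    # Gate: keep only sub-types belonging to the predicted main category.
    gated = np.where([m == main for m in main_of_53], row, 0.0)
    total = gated.sum()
    if total > 0:  # otherwise fall back to the ungated eat53 distribution
        row = gated / total
    return classes53[int(np.argmax(row))], float(row.max())

print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)
```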
Full results: BENCHMARKS.md
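Macro F1, the metric quoted for the two-stage setup, is the unweighted mean of per-class F1 scores. Rare sub-types therefore count as much as frequent ones. Below is a dependency-free sketch for scoring predictions on your own labelled questions. It averages over every class seen in either the true or predicted labels, which should agree with scikit-learn's `f1_score(..., average="macro")` default behaviour. The toy labels are invented for illustration and are not results from BENCHMARKS.md.

```python
from collections import Counter

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every class seen in y_true or y_pred."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p wrongly
            fn[t] += 1  # missed the true class t
    classes = set(y_true) | set(y_pred)
    # Per-class F1 = 2TP / (2TP + FP + FN); define it as 0 when the class never occurs.
    scores = [2 * tp[c] / d if (d := 2 * tp[c] + fp[c] + fn[c]) else 0.0
              for c in classes]
    return sum(scores) / len(scores)

# Toy example, not real model output.
y_true = ["HUM:ind", "HUM:ind", "LOC:city", "NUM:date"]
y_pred = ["HUM:ind", "LOC:city", "LOC:city", "NUM:date"]
print(round(macro_f1(y_true, y_pred), 4))  # 0.7778
```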