eat-classifiers

English question answer-type (EAT) classifiers trained on the TigreGotico/EAT dataset (30,017 questions, 53 fine-grained labels across 7 TREC categories).

Two-stage inference, in which the 7-class eat7 model gates the 53-class eat53 model, achieves 93.4% macro F1 on the test set.

Used by little_questions.

Label taxonomy

7 main categories, 53 sub-types:

| Main | Sub-types |
|------|-----------|
| ABBR | abb, exp |
| BOOL | yesno |
| DESC | def, desc, manner, reason |
| ENTY | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| HUM | desc, gr, ind, title |
| LOC | city, country, landmass, mount, other, state, water |
| NUM | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |
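As a quick sanity check on the taxonomy, the sketch below writes it out as a Python dict and flattens it into `MAIN:sub` strings. The `MAIN:sub` form is inferred from the `HUM:ind` label in the usage example; the names `TAXONOMY` and `flat_labels` are illustrative, and the ordering here is not necessarily the class order stored in the model metadata.

```python
# Taxonomy transcribed from the table above (7 main categories, 53 sub-types).
TAXONOMY = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed",
             "event", "food", "instru", "lang", "letter", "other", "plant",
             "product", "religion", "sport", "substance", "symbol",
             "techmeth", "termeq", "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}


def flat_labels(taxonomy):
    """Flatten {main: [sub, ...]} into 'MAIN:sub' strings (assumed label form)."""
    return [f"{main}:{sub}" for main, subs in taxonomy.items() for sub in subs]


labels = flat_labels(TAXONOMY)
print(len(TAXONOMY), len(labels))  # 7 53
```

Sub-type names such as `other` and `desc` repeat across main categories, so the main-category prefix is needed to keep the labels unique.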

Models

| File | Input variant | Output |
|------|---------------|--------|
| eat53_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat53_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |
| eat7_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat7_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |

Both punctuated and unpunctuated variants are provided. Use the unpunctuated (_unpunct) model for ASR / voice assistant input.
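Choosing between the two input variants can be reduced to building the right filename. The sketch below is a hypothetical helper (`model_filename` is not part of this repository) that follows the naming pattern of the calibrated SVM files listed above, which are the only ones shown with an `_unpunct` variant.

```python
def model_filename(stage, unpunctuated=False, version="0.9.0"):
    """Build the calibrated-SVM filename for a stage (7 or 53).

    Hypothetical helper: the pattern is inferred from the file listing,
    not taken from any API shipped with the models.
    """
    if stage not in (7, 53):
        raise ValueError("stage must be 7 or 53")
    suffix = "_unpunct" if unpunctuated else ""
    return f"eat{stage}_svm_cal{suffix}_EN_{version}.onnx"


# Written text uses the punctuated models; ASR transcripts use _unpunct.
print(model_filename(7))                      # eat7_svm_cal_EN_0.9.0.onnx
print(model_filename(53, unpunctuated=True))  # eat53_svm_cal_unpunct_EN_0.9.0.onnx
```

For two-stage inference it seems sensible to load both stages with the same input variant, so that the gate and the fine-grained model see text in the form they were trained on.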

Two-stage inference

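import json

import numpy as np
import onnxruntime as rt

# Stage 1 predicts one of 7 main categories; stage 2 scores all 53 labels.
sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")

# Class names are stored as a JSON list in each model's metadata.
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])

# Main category of each fine-grained label, e.g. "HUM:ind" -> "HUM".
main_of_53 = np.array([c.split(":")[0] for c in classes53])

def classify(text):
    inp = np.array([text], dtype=object)
    # Stage 1: index of the predicted main category.
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Stage 2: calibrated probabilities over all 53 labels.
    _, probs = sess53.run(None, {"input": inp})
    # Zero out labels outside the predicted main category, then renormalize.
    row = np.where(main_of_53 == main, probs[0], 0.0)
    total = row.sum()
    if total > 0:  # avoid dividing by zero if the gated category has no mass
        row = row / total
    best = int(np.argmax(row))
    return classes53[best], float(row[best])

print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)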
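The gating step can be tried without the ONNX models. The sketch below applies the same mask-and-renormalize logic to a toy probability vector; the function name `gate`, the four labels, and the probabilities are made up for illustration.

```python
import numpy as np


def gate(probs, sub_labels, main):
    """Keep only sub-labels under `main`, renormalize, return (label, prob)."""
    mask = np.array([label.split(":")[0] == main for label in sub_labels])
    gated = np.where(mask, probs, 0.0)
    total = gated.sum()
    if total == 0.0:
        raise ValueError(f"no probability mass under main category {main!r}")
    gated = gated / total
    best = int(np.argmax(gated))
    return sub_labels[best], float(gated[best])


# Toy stage-2 output: ungated, the top label would be LOC:city.
toy_labels = ["HUM:ind", "HUM:gr", "LOC:city", "NUM:date"]
toy_probs = np.array([0.25, 0.125, 0.5, 0.125])

# If stage 1 says HUM, only the HUM sub-labels compete.
label, prob = gate(toy_probs, toy_labels, "HUM")
print(label, round(prob, 3))
```

Here the gate overrides the fine-grained model's top choice: `HUM:ind` wins with its probability rescaled relative to the HUM labels only. This sketch raises an error when the gated category has no mass; other fallbacks, such as returning the ungated top label, are equally plausible.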

Benchmarks

Full results: BENCHMARKS.md
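For readers unfamiliar with the headline metric, macro F1 is the unweighted mean of per-class F1 scores, so rare sub-types count as much as common ones. The sketch below is a minimal pure-Python version on made-up predictions; it is an illustration of the metric, not necessarily the evaluation code behind the reported numbers.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all classes seen in either list."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never occurs.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Made-up example: one HUM:ind question is misclassified as LOC:city.
y_true = ["HUM:ind", "HUM:ind", "LOC:city", "NUM:date"]
y_pred = ["HUM:ind", "LOC:city", "LOC:city", "NUM:date"]
print(round(macro_f1(y_true, y_pred), 4))  # 0.7778
```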
