1from __future__ import annotations
2
3import logging
4import os
5from contextlib import asynccontextmanager
6
7import torch
8from fastapi import FastAPI, Request
9from prometheus_client import Counter, Histogram, make_asgi_app
10
11from app.models.classifier import Classifier
12from app.schemas import PredictRequest, PredictResponse
13from app.metrics import COUNT_DESC
14
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Histogram timed around each classifier call; one series per "model" label.
INFERENCE_LATENCY = Histogram(
    name="inference_seconds",
    documentation="Wall time per inference",
    labelnames=["model"],
)

# Counter of inferences per "model" label; its help text is the
# COUNT_DESC constant imported from app.metrics.
INFERENCE_COUNT = Counter(
    name="inference_total",
    documentation=COUNT_DESC,
    labelnames=["model"],
)
20
21
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classifier at startup and release it at shutdown.

    Startup: picks ``"cuda"`` when ``torch.cuda.is_available()`` is true and
    ``"cpu"`` otherwise, loads the model from the path in the ``MODEL_PATH``
    environment variable, and stores it on ``app.state.classifier``.

    Shutdown: removes ``app.state.classifier`` and, when running on CUDA,
    calls ``torch.cuda.empty_cache()``.

    Raises:
        KeyError: if ``MODEL_PATH`` is not set in the environment.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    app.state.classifier = Classifier.from_pretrained(
        os.environ["MODEL_PATH"], device=device,
    )
    logger.info("inference-api ready on %s", device)
    try:
        yield
    finally:
        # Run teardown even if an exception is thrown into the generator at
        # the yield; otherwise the model reference and cached GPU memory
        # would never be released on an abnormal shutdown.
        del app.state.classifier
        if device == "cuda":
            torch.cuda.empty_cache()
33
34
# ASGI application; `lifespan` handles model loading and teardown.
app = FastAPI(title="inference-api", lifespan=lifespan)

# Serve the Prometheus ASGI app under /metrics.
app.mount(path="/metrics", app=make_asgi_app())
37
38
@app.post("/v1/predict", response_model=PredictResponse)
async def predict(req: PredictRequest, request: Request) -> PredictResponse:
    """Classify ``req.input`` with the classifier stored on the app state.

    The awaited ``predict`` call runs inside the ``INFERENCE_LATENCY`` timer
    for the classifier's ``name`` label. ``INFERENCE_COUNT`` is incremented
    only after that call has returned, so a call that raises is not counted.

    Returns:
        A ``PredictResponse`` built from the result's ``label`` and ``score``.
    """
    classifier = request.app.state.classifier

    latency_timer = INFERENCE_LATENCY.labels(model=classifier.name).time()
    with latency_timer:
        result = await classifier.predict(req.input)

    inference_counter = INFERENCE_COUNT.labels(model=classifier.name)
    inference_counter.inc()

    label = result.label
    score = result.score
    return PredictResponse(label=label, score=score)
46
47
@app.get("/healthz")
async def healthz() -> dict[str, str]:
    """Return a fixed ``{"status": "ok"}`` payload; no state is inspected."""
    payload: dict[str, str] = {"status": "ok"}
    return payload