Learn by Directing AI
All materials

serve.py

pyserve.py
"""FastAPI serving endpoint for coffee yield prediction model."""

import json
import logging
import math

import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from feature_pipeline import prepare_features


class YieldModel(nn.Module):
    """Feedforward regressor producing one output value per input row.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear(1). Only the first two
    entries of ``hidden_sizes`` are used. The Linear layers sit at indices
    0, 2 and 4 of ``self.network``. Those indices form the state_dict keys
    (``network.0.weight`` ...), so keep them stable for saved checkpoints.
    """

    def __init__(self, input_size=8, hidden_sizes=(32, 16)):
        super().__init__()
        first_width = hidden_sizes[0]
        second_width = hidden_sizes[1]
        # Construct layers in the same order as they run, so parameter
        # initialisation consumes the RNG in a fixed sequence.
        stack = [
            nn.Linear(input_size, first_width),
            nn.ReLU(),
            nn.Linear(first_width, second_width),
            nn.ReLU(),
            nn.Linear(second_width, 1),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the raw output."""
        output = self.network(x)
        return output


class PredictionRequest(BaseModel):
    """Request body for ``POST /predict``: a farm id plus its sensor readings.

    Every field is required (``default=...``). Units appear in each field's
    description. This model declares no range constraints.
    """

    # Identifier of the farm these readings belong to.
    farm_id: str = Field(default=..., description="Farm identifier")
    # Environmental readings, each a plain float.
    temperature: float = Field(default=..., description="Temperature in Celsius")
    rainfall: float = Field(default=..., description="Rainfall in mm/day")
    soil_moisture: float = Field(default=..., description="Soil moisture percentage")
    humidity: float = Field(default=..., description="Humidity percentage")
    # Site elevation.
    altitude: float = Field(default=..., description="Altitude in meters")


class PredictionResponse(BaseModel):
    """Response body for ``POST /predict``.

    Both fields are required and have no defaults.
    """

    farm_id: str  # farm identifier the prediction refers to
    predicted_yield_kg: float  # predicted yield, in kilograms per the field name


app = FastAPI(title="Finca Esperanza Yield Prediction")

# Load the model once at import time; every request reuses this instance.
# Paths below are relative to the current working directory, not to this file.
# weights_only=True keeps torch.load from unpickling arbitrary objects.
model = YieldModel()
model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True))
model.eval()  # inference mode (affects dropout/batch-norm style layers)

# Load the sensor schema. JSON is UTF-8 by spec, so don't rely on the
# platform's default locale encoding.
# NOTE(review): nothing in this file reads `sensor_schema`, so requests are
# NOT validated against it yet. Either wire it into validation or drop it.
with open("sensor-schema.json", encoding="utf-8") as f:
    sensor_schema = json.load(f)


@app.get("/health")
def health():
    """Return a fixed liveness payload.

    Performs no checks on the model or any other dependency.
    """
    payload = {"status": "healthy"}
    return payload


@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Generate a yield prediction for a single farm.

    The request fields are packed into a dict and converted into model input
    by ``prepare_features``. That input runs through the module-level
    ``model`` with autograd disabled. The scalar output is clamped to
    [0, 50000] kg and rounded to 2 decimal places.

    Args:
        request: Parsed request body with the farm id and sensor readings.

    Returns:
        PredictionResponse echoing ``farm_id``, with the predicted yield in kg.

    Raises:
        HTTPException: status 500 in two cases. The first is when feature
            preparation or inference raises; the exception is logged
            server-side with its traceback and is not echoed to the client.
            The second is when the model returns NaN or inf.
    """
    # Plausibility bounds for one farm's yield, in kg.
    min_yield_kg, max_yield_kg = 0.0, 50000.0
    logger = logging.getLogger(__name__)

    raw_data = {
        "farm_id": request.farm_id,
        "temperature": request.temperature,
        "rainfall": request.rainfall,
        "soil_moisture": request.soil_moisture,
        "humidity": request.humidity,
        "altitude": request.altitude,
    }

    try:
        features = prepare_features(raw_data)
        with torch.no_grad():
            prediction = model(features)
        # .item() requires a single-element tensor, i.e. exactly one prediction.
        predicted_yield = float(prediction.item())
    except Exception as exc:
        # Boundary handler. Log the full traceback for operators, but send
        # clients a generic message so internal details (file paths, tensor
        # shapes, library errors) are not exposed.
        logger.exception("Prediction failed for farm_id=%r", request.farm_id)
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    # Clamping with min()/max() would silently turn NaN into 0.0 and inf into
    # a bound, reporting garbage as a valid yield. Reject it instead.
    if not math.isfinite(predicted_yield):
        logger.error(
            "Non-finite prediction %r for farm_id=%r", predicted_yield, request.farm_id
        )
        raise HTTPException(
            status_code=500, detail="Model produced a non-finite prediction"
        )

    predicted_yield = max(min_yield_kg, min(predicted_yield, max_yield_kg))

    return PredictionResponse(
        farm_id=request.farm_id,
        predicted_yield_kg=round(predicted_yield, 2),
    )


if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8000.
    # Host 0.0.0.0 listens on all network interfaces.
    uvicorn.run(app, port=8000, host="0.0.0.0")