# Run a catgirl classifier model.

# Python · 26 views
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ResNet-50 architecture built without pretrained weights; its final
# fully-connected layer is swapped for a two-output head. The head has to
# match the checkpoint because load_state_dict is strict by default.
model = models.resnet50()
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code if the file is untrusted. Only load checkpoints you trust,
# or pass weights_only=True on PyTorch versions that support it.
model.load_state_dict(
    torch.load("catgirl_detector.pth", map_location=device)
)

# Move the parameters to the chosen device and switch to inference mode
# (batch-norm layers use their running statistics).
model.to(device).eval()

# Preprocessing applied to every input image:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize as (x - 0.5) / 0.5, mapping values to [-1, 1].
# The single-element mean/std broadcasts over however many channels the
# tensor has.
# NOTE(review): these statistics must match the normalization used when the
# checkpoint was trained. Confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

def predict(image_path):
    """Classify a single image and return the score as a JSON string.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a filesystem path or
            a binary file object).

    Returns:
        A JSON string of the form ``{"catGirlProbability": p}``, where ``p``
        is the softmax probability of output class 0, rounded to 4 decimal
        places.
    """
    # Decode and convert while the file is open, then close it
    # deterministically instead of leaving the handle to garbage collection.
    # convert() returns a new, fully loaded image, so it stays valid after
    # the source image is closed.
    with Image.open(image_path) as raw:
        rgb = raw.convert("RGB")

    # unsqueeze(0) adds a leading batch dimension of size 1 for the model.
    batch = transform(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0]
        # NOTE(review): assumes output index 0 is the "catgirl" class.
        # Confirm against the label order used to train the checkpoint.
        catgirl_prob = probabilities[0].item()

    return json.dumps({"catGirlProbability": round(catgirl_prob, 4)})

if __name__ == "__main__":
    import sys

    # Image to classify: the first command-line argument, or the placeholder
    # below when none is given (fill it in for quick manual runs).
    image_path = sys.argv[1] if len(sys.argv) > 1 else r""  # some image here
    if not image_path:
        # An empty path would otherwise fail with a bare file-not-found error.
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    print(predict(image_path))