# Run a Catgirl classifier model
# Python··26 views
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import json
# Use the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# ResNet-50 built with default arguments, with its final fully connected
# layer replaced by a 2-output linear head.
model = models.resnet50()
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)

# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source.
model.load_state_dict(
    torch.load("catgirl_detector.pth", map_location=device)
)

# Move the parameters to the selected device and switch to inference mode.
model.to(device).eval()
# Preprocessing applied to every input image, in order: resize to 224x224,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(image_path: str) -> str:
    """Classify a single image and return the result as a JSON string.

    The image is opened with PIL, converted to RGB, run through the
    module-level ``transform`` and ``model``, and the softmax probability at
    class index 0 is reported.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A JSON object string of the form ``{"catGirlProbability": <float>}``,
        with the probability rounded to four decimal places.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (raised by
            ``PIL.Image.open``).
    """
    # The context manager closes the underlying file deterministically.
    # convert() returns an independent copy, so the RGB image remains usable
    # after the file has been closed.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # Add a leading batch dimension and move the tensor to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0]

    # NOTE(review): assumes class index 0 is the "catgirl" class -- confirm
    # against the label order used when the checkpoint was trained.
    catgirl_prob = probabilities[0].item()
    return json.dumps({"catGirlProbability": round(catgirl_prob, 4)})
# Path of the image to classify; fill this in before running the script.
image_path = r""  # some image here

# Run inference only when executed as a script, so importing this module to
# reuse predict() does not trigger a prediction on the placeholder path.
if __name__ == "__main__":
    print(predict(image_path))