/*
 * Decompiled with CFR 0.152.
 */
package boofcv.processing;

import boofcv.abst.scene.ImageClassifier;
import boofcv.factory.scene.ClassifierAndSource;
import boofcv.processing.ConvertProcessing;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.io.DeepBoofDataBaseOps;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import processing.core.PImage;

/**
 * Processing-friendly wrapper around a BoofCV {@link ImageClassifier} that accepts
 * {@link PImage} inputs and reports results using category names instead of indices.
 *
 * <p>Usage: construct from a {@link ClassifierAndSource}, call {@link #loadModel(String, boolean)}
 * once, then call {@link #classify(PImage)} per image and optionally {@link #getAllScores()}.
 *
 * <p>Not thread-safe: {@link #classify(PImage)} reshapes and overwrites a single shared
 * {@link Planar} buffer without synchronization.
 */
public class SimpleImageClassification {
    // Underlying BoofCV classifier that does the actual work.
    ImageClassifier<Planar<GrayF32>> classifier;
    // Model download locations, handed to DeepBoofDataBaseOps.downloadModel when downloading.
    List<String> sources;
    // NOTE(review): never read or written in this class; kept for source compatibility.
    File path;
    // Reusable 3-band float image; resized on every classify() call to match the input.
    Planar<GrayF32> boofImage = new Planar<>(GrayF32.class, 1, 1, 3);
    // True only after the most recent loadModel() call completed without an exception.
    boolean modelLoaded = false;

    /**
     * Creates a wrapper around the classifier and model sources in {@code cs}.
     *
     * @param cs supplies the classifier instance and the list of model download sources
     */
    public SimpleImageClassification(ClassifierAndSource cs) {
        this.classifier = cs.getClassifier();
        this.sources = cs.getSource();
    }

    /**
     * Loads the classifier's model, optionally downloading it first.
     *
     * @param location where the model lives; {@code "download_data"} is used when {@code null}.
     *                 When {@code download} is true, this is the destination given to the
     *                 downloader and is replaced by the path the downloader returns.
     * @param download if true, fetch the model from the constructor-supplied sources before loading
     * @throws UncheckedIOException if downloading or loading the model fails
     */
    public void loadModel(String location, boolean download) {
        String target = (location == null) ? "download_data" : location;
        // Clear the flag up front: if this (re)load fails, the classifier's state is unknown and
        // classify() must not keep using it as if a previous load were still valid.
        this.modelLoaded = false;
        try {
            if (download) {
                target = DeepBoofDataBaseOps.downloadModel(this.sources, new File(target)).getPath();
            }
            this.classifier.loadModel(new File(target));
            this.modelLoaded = true;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load model from " + target, e);
        }
    }

    /**
     * Classifies an image and returns the name of the best-scoring category.
     *
     * @param image the image to classify; converted from RGB into the internal buffer
     * @return the category name at index {@code classifier.getBestResult()}
     * @throws IllegalStateException if no model has been successfully loaded
     * @throws NullPointerException if {@code image} is null
     */
    public String classify(PImage image) {
        if (!this.modelLoaded) {
            throw new IllegalStateException("Need to download and load the model");
        }
        Objects.requireNonNull(image, "image");
        this.boofImage.reshape(image.width, image.height);
        ConvertProcessing.convertFromRGB(image, this.boofImage);
        this.classifier.classify(this.boofImage);
        List<String> categories = this.classifier.getCategories();
        return categories.get(this.classifier.getBestResult());
    }

    /**
     * Returns every result from the classifier's most recent run, with each category index
     * translated to its name.
     *
     * <p>NOTE(review): this does not check that a model is loaded or that
     * {@link #classify(PImage)} has been called. What the classifier reports in that case
     * cannot be determined from this class.
     *
     * @return a new mutable list of scores, in the order the classifier reports them
     */
    public List<Score> getAllScores() {
        List<String> categories = this.classifier.getCategories();
        List<Score> scores = new ArrayList<>();
        for (ImageClassifier.Score s : this.classifier.getAllResults()) {
            Score a = new Score();
            a.score = s.score;
            a.category = categories.get(s.category);
            scores.add(a);
        }
        return scores;
    }

    /**
     * Returns the classifier's category names.
     *
     * @return the list returned directly by the underlying classifier
     */
    public List<String> getCategories() {
        return this.classifier.getCategories();
    }

    /** A single classification result: a category name and its score. */
    public static class Score {
        // Human-readable category name.
        public String category;
        // Score value, copied as-is from ImageClassifier.Score.
        public double score;
    }
}

