using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows;
using System.Windows.Forms;
using System.Windows.Controls;
using System.Windows.Data;
using System.Windows.Documents;
using System.Windows.Input;
using System.Windows.Media;
using System.Windows.Media.Imaging;
using System.Windows.Navigation;
using System.Windows.Shapes;
using TensorFlow;

namespace TensorTest
{
    /// <summary>
    /// Interaction logic for MainWindow.xaml. On button click, loads the Inception graph
    /// (<see cref="modelfilename"/>), classifies <see cref="testimage"/> and writes the
    /// best-matching label from <see cref="labelfile"/> to the text box.
    /// </summary>
    public partial class MainWindow : Window
    {
        /// <summary>Frozen model graph; resolved relative to the executable's directory.</summary>
        public const string modelfilename = "tensorflow_inception_graph.pb";
        /// <summary>Image to classify; resolved relative to the executable's directory.</summary>
        public const string testimage = "test.jpg"; //image of a parrot
        /// <summary>One label per line, indexed by the model's output position.</summary>
        public const string labelfile = "imagenet_comp_graph_label_strings.txt";

        public MainWindow()
        {
            InitializeComponent();
        }

        /// <summary>
        /// Reads a JPEG file and converts it into the normalized image tensor the model expects
        /// (see <see cref="ConstructGraphToNormalizeImage"/>).
        /// </summary>
        /// <param name="file">Path of the JPEG image.</param>
        /// <param name="destinationDataType">Element type of the returned tensor.</param>
        /// <returns>The normalized image tensor; the caller is responsible for disposing it.</returns>
        public static TFTensor CreateTensorFromImageFile(string file, TFDataType destinationDataType = TFDataType.Float)
        {
            var contents = File.ReadAllBytes(file);

            TFOutput input, output;

            // DecodeJpeg takes a scalar String-valued tensor as input. The input tensor,
            // the helper graph and its session are only needed for this one run, so all
            // three are disposed once the normalized tensor has been produced.
            using (var tensor = TFTensor.CreateString(contents))
            using (var graph = ConstructGraphToNormalizeImage(out input, out output, destinationDataType))
            using (var session = new TFSession(graph))
            {
                var normalized = session.Run(
                    inputs: new[] { input },
                    inputValues: new[] { tensor },
                    outputs: new[] { output });

                return normalized[0];
            }
        }

        /// <summary>
        /// Builds a graph that decodes a JPEG string, adds a batch dimension, resizes the image
        /// to the model's input size and normalizes each channel value as (value - Mean) / Scale.
        /// </summary>
        /// <param name="input">Receives the String placeholder to feed the raw JPEG bytes into.</param>
        /// <param name="output">Receives the output carrying the normalized [1, H, W, 3] batch.</param>
        /// <param name="destinationDataType">Element type the result is cast to.</param>
        /// <returns>The constructed graph; the caller owns and must dispose it.</returns>
        private static TFGraph ConstructGraphToNormalizeImage(out TFOutput input, out TFOutput output, TFDataType destinationDataType = TFDataType.Float)
        {
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
            //
            // - The model was trained with images scaled to 224x224 pixels.
            // - The colors, represented as R, G, B in 1-byte each, were converted to
            //   float using (value - Mean)/Scale.

            const int W = 224;
            const int H = 224;
            const float Mean = 117;
            const float Scale = 1;

            var graph = new TFGraph();
            input = graph.Placeholder(TFDataType.String);

            output = graph.Cast(graph.Div(
                x: graph.Sub(
                    x: graph.ResizeBilinear(
                        images: graph.ExpandDims(
                            input: graph.Cast(
                                graph.DecodeJpeg(contents: input, channels: 3), DstT: TFDataType.Float),
                            dim: graph.Const(0, "make_batch")),
                        // ResizeBilinear expects [new_height, new_width].
                        size: graph.Const(new int[] { H, W }, "size")),
                    y: graph.Const(Mean, "mean")),
                y: graph.Const(Scale, "scale")), destinationDataType);

            return graph;
        }

        /// <summary>Directory containing the executable; set on each button click.</summary>
        static string basedir;

        /// <summary>
        /// Classifies <see cref="testimage"/> with the model in <see cref="modelfilename"/> and
        /// writes the best label to the text box. Missing files, missing graph operations and
        /// unexpected output shapes are reported in the text box instead of crashing or
        /// terminating the application.
        /// </summary>
        private void button_Click(object sender, RoutedEventArgs e)
        {
            // The model can be downloaded from:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
            basedir = System.IO.Path.GetDirectoryName(System.Windows.Forms.Application.ExecutablePath);
            textBox.Clear();

            // Path.Combine avoids a doubled separator when basedir is a drive root.
            string MFwithPath = System.IO.Path.Combine(basedir, modelfilename);
            string LFwithPath = System.IO.Path.Combine(basedir, labelfile);
            string JPGwithPath = System.IO.Path.Combine(basedir, testimage);

            textBox.AppendText(MFwithPath + "\n");

            if (!File.Exists(MFwithPath))
            {
                textBox.AppendText("Error: model file not found: " + MFwithPath + "\n");
                return;
            }
            if (!File.Exists(JPGwithPath))
            {
                textBox.AppendText("Error: test image not found: " + JPGwithPath + "\n");
                return;
            }
            if (!File.Exists(LFwithPath))
            {
                textBox.AppendText("Error: label file not found: " + LFwithPath + "\n");
                return;
            }

            // Show the image that is about to be classified.
            var bitmap = new BitmapImage();
            bitmap.BeginInit();
            bitmap.UriSource = new Uri(JPGwithPath);
            bitmap.EndInit();
            image.Source = bitmap;

            var labels = File.ReadAllLines(LFwithPath);

            // Graph, session and tensors wrap native TensorFlow handles; dispose them so
            // repeated clicks do not leak native resources.
            using (var graph = new TFGraph()) // This will throw a BadImage exception if you're not set to x64 in Build options.
            {
                graph.Import(File.ReadAllBytes(MFwithPath), "");

                // The indexer returns null when the graph has no operation with that name.
                var inputOp = graph["input"];
                var outputOp = graph["output"];
                if (inputOp == null || outputOp == null)
                {
                    textBox.AppendText("Error: model does not contain the expected 'input' and 'output' operations\n");
                    return;
                }

                using (var session = new TFSession(graph))
                using (var tensor = CreateTensorFromImageFile(JPGwithPath)) // Image in the format the model reads.
                {
                    var runner = session.GetRunner();
                    runner.AddInput(inputOp[0], tensor).Fetch(outputOp[0]);

                    using (var result = runner.Run()[0])
                    {
                        ReportBestMatch(result, labels);
                    }
                }
            }
        }

        /// <summary>
        /// Writes the highest-probability label of a [1, N] result tensor to the text box.
        /// Writes an error message instead if the tensor has another shape, or if the winning
        /// index has no corresponding entry in <paramref name="labels"/>.
        /// </summary>
        /// <param name="result">Model output; expected to hold one probability per label for a batch of 1.</param>
        /// <param name="labels">Labels indexed by output position.</param>
        private void ReportBestMatch(TFTensor result, string[] labels)
        {
            var rshape = result.Shape;
            if (result.NumDims != 2 || rshape[0] != 1)
            {
                textBox.AppendText("Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [" + string.Join(" ", rshape) + "]");
                return;
            }

            // GetValue(jagged: true) yields float[][] (one row per batch item);
            // GetValue(jagged: false) would yield the equivalent float[,] instead.
            var probabilities = ((float[][])result.GetValue(jagged: true))[0];

            // Start below any real value so the first element is always a candidate.
            int bestIdx = -1;
            float best = float.NegativeInfinity;
            for (int i = 0; i < probabilities.Length; i++)
            {
                if (probabilities[i] > best)
                {
                    bestIdx = i;
                    best = probabilities[i];
                }
            }

            if (bestIdx < 0 || bestIdx >= labels.Length)
            {
                textBox.AppendText("Error: could not map the model output to a label (" + probabilities.Length + " outputs, " + labels.Length + " labels)\n");
                return;
            }

            float bestper = best * 100f;
            textBox.AppendText("Best match: [" + bestIdx + "] " + bestper + "% " + labels[bestIdx]);
        }
    }
}