import React, { useState, useCallback, useEffect, useRef } from 'react';
import Modal from 'react-modal';
import Webcam from 'react-webcam';
import classnames from 'classnames';
import * as tf from '@tensorflow/tfjs';
import InputRange from 'react-input-range';

import 'react-input-range/lib/css/index.css';

// Side length, in pixels, shared by every canvas in this file, the webcam
// preview, and the square that loaded images are resized to before use.
const IMAGE_SIZE = 320;

/**
 * Loads an image, draws it stretched to IMAGE_SIZE x IMAGE_SIZE at the
 * top-left corner of a canvas, and reports the resulting pixels.
 *
 * @param {string} imgSrc - Value assigned to the image's `src` (callers in
 *   this file pass data URLs).
 * @param {CanvasRenderingContext2D} ctx - 2D context the image is drawn into.
 * @param {function(ImageData): void} dataCallback - Receives the pixels read
 *   back from the canvas once the image has loaded and been drawn. It is not
 *   invoked when the image fails to load.
 */
const updateImage = (imgSrc, ctx, dataCallback) => {
  const img = new Image();

  img.onload = () => {
    // Wipe the previous picture first so that transparent regions of the new
    // image do not let stale pixels leak into the captured data.
    ctx.clearRect(0, 0, IMAGE_SIZE, IMAGE_SIZE);
    ctx.drawImage(img, 0, 0, IMAGE_SIZE, IMAGE_SIZE);
    dataCallback(ctx.getImageData(0, 0, IMAGE_SIZE, IMAGE_SIZE));
  };

  img.onerror = () => {
    // Surface undecodable input instead of failing silently.
    console.error('Unable to load the selected image');
  };

  // Assign the source last so both handlers are in place before loading starts.
  img.src = imgSrc;
};

/**
 * Runs style transfer: computes a style "bottleneck" with `styleModel`,
 * optionally blends it with the bottleneck of the content image itself, and
 * feeds the content image plus the bottleneck to `transformModel`.
 *
 * @param {{predict: Function}} styleModel - Model producing a bottleneck
 *   tensor from a batched image tensor.
 * @param {{predict: Function}} transformModel - Model called with
 *   `[contentTensor, bottleneck]`.
 * @param {*} contentImage - Pixel source accepted by `tf.browser.fromPixels`
 *   (this file passes `ImageData`).
 * @param {*} styleImage - Pixel source for the style picture, same kinds as
 *   `contentImage`.
 * @param {number} [styleRatio=0.5] - Weight of the style bottleneck; the
 *   content bottleneck is weighted `1 - styleRatio`. At exactly 1 the blend
 *   step is skipped.
 * @returns {Promise<tf.Tensor>} The squeezed output of `transformModel`. The
 *   caller owns this tensor and must dispose it.
 */
const runStyleImage = async (styleModel, transformModel, contentImage, styleImage, styleRatio = 0.5) => {
  // Pixels -> float tensor scaled by 1/255 with a leading batch dimension.
  // Must be called inside tf.tidy so the intermediate tensors are released.
  const toModelInput = (pixels) =>
    tf.browser.fromPixels(pixels).toFloat().div(tf.scalar(255)).expandDims();

  // Yield to the browser between heavy steps so the page can repaint.
  await tf.nextFrame();
  await tf.nextFrame();
  let bottleneck = tf.tidy(() => styleModel.predict(toModelInput(styleImage)));

  try {
    if (styleRatio !== 1.0) {
      await tf.nextFrame();
      const styleBottleneck = bottleneck;
      const identityBottleneck = tf.tidy(() => styleModel.predict(toModelInput(contentImage)));

      let blended;
      try {
        blended = tf.tidy(() => {
          const styleBottleneckScaled = styleBottleneck.mul(tf.scalar(styleRatio));
          const identityBottleneckScaled = identityBottleneck.mul(tf.scalar(1.0 - styleRatio));
          return styleBottleneckScaled.addStrict(identityBottleneckScaled);
        });
      } finally {
        identityBottleneck.dispose();
      }

      // Only swap once blending succeeded; on failure the outer `finally`
      // still releases the original style bottleneck.
      styleBottleneck.dispose();
      bottleneck = blended;
    }

    await tf.nextFrame();
    return tf.tidy(() => transformModel.predict([toModelInput(contentImage), bottleneck]).squeeze());
  } finally {
    // Release the bottleneck whether or not the transform step succeeded.
    bottleneck.dispose();
  }
};

function StyleImage() {
  const inputCanvas = useRef(null);
  const styleCanvas = useRef(null);
  const resultCanvas = useRef(null);
  const webcam = useRef(null);

  const [styleModel, updateStyleModel] = useState(null);
  const [transformModel, updateTransformModel] = useState(null);
  const [imageData, updateImageData] = useState(null);
  const [styleData, updateStyleData] = useState(null);
  const [styleRatio, updateStyleRatio] = useState(0.5);
  const [modalOpen, updateModalOpen] = useState(false);
  const [loading, updateLoading] = useState(false);

  const openModal = useCallback(() => updateModalOpen(true));
  const closeModal = useCallback(() => updateModalOpen(false));

  const updateContentImage = useCallback((e) => {
    const ctx = inputCanvas.current.getContext("2d")

    const reader = new FileReader();
    reader.readAsDataURL(e.target.files[0]);
    reader.onload = (e2) => {
      updateImage(e2.target.result, ctx, updateImageData);
    }
  }, [inputCanvas])

  const captureContentImage = useCallback(() => {
    const image = webcam.current.getScreenshot();
    const ctx = inputCanvas.current.getContext("2d")

    updateImage(image, ctx, updateImageData);
    closeModal();
  });

  const updateStyleImage = useCallback((e) => {
    const ctx = styleCanvas.current.getContext("2d")

    const reader = new FileReader();
    reader.readAsDataURL(e.target.files[0]);
    reader.onload = (e2) => {
      updateImage(e2.target.result, ctx, updateStyleData)
    }
  }, [styleCanvas])

  const styleImage = useCallback(() => {
    updateLoading(true);
    runStyleImage(styleModel, transformModel, imageData, styleData, styleRatio).then((newImage) => {
      tf.browser.toPixels(newImage, resultCanvas.current).finally(() => newImage.dispose())
      updateLoading(false);
    })
  }, [styleModel, transformModel, imageData, styleData, styleRatio]);

  const convertAndUpdateStyleRatio = useCallback((value) => {
    updateStyleRatio(value / 100.);
  })

  const loadModels = useCallback(async () => {
    const loadedStyleModel = await tf.loadFrozenModel(
      '/models/style/tensorflowjs_model.pb',
      '/models/style/weights_manifest.json'
    )
    updateStyleModel(loadedStyleModel);

    const loadedTransformModel = await tf.loadFrozenModel(
      '/models/separableTransform/tensorflowjs_model.pb',
      '/models/separableTransform/weights_manifest.json'
    )
    updateTransformModel(loadedTransformModel);
  }, []);

  useEffect(() => {
    loadModels();
  }, [])

  return (
    <div className="container">
      <div className="columns">
        <div className="column is-half">
          <h1>Content Image</h1>
          <button type="button" className="button is-info" onClick={openModal}>Webcam</button>
          <div className="jbtn-file">
            <div className="button is-primary ">Upload</div>
            <input type="file" onChange={updateContentImage} />
          </div>
          
          <canvas ref={inputCanvas} width={IMAGE_SIZE} height={IMAGE_SIZE} />
        </div>
        <div className="column is-half">
          <h2>Style Image</h2>
          <div className="jbtn-file">
            <div className="button is-primary ">Upload</div>
            <input type="file" onChange={updateStyleImage} />
          </div>
          
          <canvas ref={styleCanvas} width={IMAGE_SIZE} height={IMAGE_SIZE} />
        </div>
      </div>
      
      
      <div className="columns">
        <div className="column is-half is-offset-one-quarter">
          <h2>Style Ratio</h2>
          <div style={{ margin: '20px' }}>
            <InputRange
              maxValue={100}
              value={styleRatio * 100}
              onChange={convertAndUpdateStyleRatio}
            />
          </div>
          {imageData && styleData && 
            <button 
              className={classnames('button', 'is-primary', { 'is-loading': loading })} 
              onClick={styleImage}
            >Apply Styles</button>}
          <h2>Result Image</h2>
          <canvas ref={resultCanvas} width={IMAGE_SIZE} height={IMAGE_SIZE} />
        </div>
      </div>
      
      
      <Modal isOpen={modalOpen}>
        <div class="container">
          <Webcam  
            audio={false}
            height={IMAGE_SIZE}
            width={IMAGE_SIZE}
            ref={webcam}
            screenshotFormat="image/jpeg"
            videoConstraints={{ facingMode: "environment" }}
          />
        </div>
        <div class="container">
          <button className="button is-info" onClick={captureContentImage}>Capture</button>
          <button className="button is-danger" onClick={closeModal}>Close</button>
        </div>
      </Modal>
    </div>
  )
}

export default StyleImage;
