import { useCallback, useLayoutEffect, useMemo, useRef, useState } from 'react'

/**
 * Options accepted by `useDragCustomerStoriesSlider`.
 */
interface Args {
  /** Index of the slide the track is translated to. */
  currentSlide: number
  /** Number of slides; not read by `useDragCustomerStoriesSlider`. */
  totalSlides: number
  /** Space-separated CSS classes removed while dragging and re-added on release. */
  transitionClasses: string
  /** Receives the slide index picked when a drag ends. */
  onCurrentSlideChange: (currentSlide: number) => void
}

/** Minimum drag distance (px) below which a release keeps the current slide. */
const TRANSLATE_THRESHOLD = 50
/** Class added to the track when it is snapped to a slide after a drag. */
const TRANSITION_CLASSES = 'duration-500'

/**
 * Adds mouse/touch drag behaviour to the customer stories slider track.
 *
 * The track is moved by writing `style.transform` directly on the element
 * attached through `slideTrackerProps.ref`:
 * - whenever `currentSlide` changes, the track is translated to that slide's
 *   `offsetLeft`;
 * - while pressed, the track follows the pointer (`transitionClasses` are
 *   removed so it does not animate);
 * - on release (or when the mouse leaves the track mid-drag), the slide to
 *   snap to is computed, reported through `onCurrentSlideChange`, and the
 *   track is translated to it with the transition classes re-applied.
 *
 * `totalSlides` is accepted but not read here.
 *
 * @returns Pointer handlers, `onItemClick` (cancels the click that ends a
 *   drag), the `userPressing` state, and `slideTrackerProps` to spread on the
 *   track element.
 */
export function useDragCustomerStoriesSlider(args: Args) {
  const { onCurrentSlideChange, transitionClasses, currentSlide } = args

  const sliderTrackRef = useRef<HTMLElement>(null)
  // Read by the handlers; kept in step with the `userPressing` state.
  const pressingRef = useRef(false)
  // True once the track moved during the current press; `onItemClick` uses it
  // to cancel the click that ends a drag.
  const slideTrackerMovedRef = useRef(false)
  // Pointer X and track translate captured when the press started.
  const initialClientXRef = useRef<number | undefined>()
  const initialTranslateAxisRef = useRef(0)
  const [userPressing, setUserPressing] = useState(pressingRef.current)
  // Latest translate applied while dragging.
  const translateAxisRef = useRef(initialTranslateAxisRef.current)

  useLayoutEffect(
    function translateSliderTrack() {
      const el = sliderTrackRef.current

      if (!el) {
        return
      }

      // The slide may not be rendered yet (or `currentSlide` may be out of
      // range); reading `offsetLeft` from undefined would throw.
      const slide = el.children[currentSlide] as HTMLElement | undefined

      if (!slide) {
        return
      }

      el.style.transform = `translate3d(-${slide.offsetLeft}px, 0, 0)`
    },
    [currentSlide],
  )

  const onMouseTouchDown = useCallback(
    (e: React.MouseEvent | React.TouchEvent) => {
      const clientX = getClientX(e)

      // A new press starts a new gesture. A previous drag may never have been
      // followed by a click on an item (touch drags fire no click, and a mouse
      // drag across items clicks their common ancestor), so clear its flag
      // here. Otherwise the next real item click would be swallowed.
      slideTrackerMovedRef.current = false

      window.requestAnimationFrame(() => {
        setUpToMove({
          clientX,
          pressingRef,
          initialClientXRef,
          translateAxisRef,
          sliderTrackRef,
          setUserPressing,
          initialTranslateAxisRef,
          transitionClasses,
        })
      })
    },
    [transitionClasses],
  )

  const onMouseTouchMove = useCallback(
    (e: React.MouseEvent | React.TouchEvent) => {
      const clientX = getClientX(e)

      window.requestAnimationFrame(() => {
        if (
          !pressingRef.current ||
          !sliderTrackRef.current ||
          typeof initialClientXRef.current !== 'number'
        ) {
          return
        }

        slideTrackerMovedRef.current = true
        const el = sliderTrackRef.current

        // Follow the pointer relative to where the press started.
        translateAxisRef.current =
          initialTranslateAxisRef.current + clientX - initialClientXRef.current

        el.style.transform = `translate3d(${translateAxisRef.current}px, 0, 0)`
      })
    },
    [],
  )

  const onMouseTouchUp = useCallback(() => {
    if (!pressingRef.current) {
      return
    }

    pressingRef.current = false
    setUserPressing(false)

    window.requestAnimationFrame(() => {
      const el = sliderTrackRef.current

      if (!el) {
        return
      }

      // Re-enable the transition so the snap below animates. Empty tokens are
      // dropped because `classList.add('')` throws.
      el.classList.add(...transitionClasses.split(/\s+/).filter(Boolean))

      /**
       * We calculate the accumulated offsets to support different
       * sizes from slides. The current implementation doesn't
       * have different sizes between slides, but we can abstract
       * this logic into something that supports it for any slide
       * type.
       */
      const accumulatedOffsets = calculateAccumulatedOffset(el)

      const newActiveSlide = getNewActiveSlide({
        currentSlide,
        translateAxisRef,
        initialTranslateAxisRef,
        accumulatedOffsets,
      })

      onCurrentSlideChange(newActiveSlide)

      resetAndTranslate({
        pressingRef,
        initialClientXRef,
        translateAxisRef,
        newActiveSlide,
        sliderTrackRef,
        initialTranslateAxisRef,
        accumulatedOffsets,
      })
    })
  }, [currentSlide, onCurrentSlideChange, transitionClasses])

  const onItemClick = useCallback((e: React.MouseEvent) => {
    if (!slideTrackerMovedRef.current) {
      return
    }

    // The click is the tail end of a drag, not a real click on the item.
    e.preventDefault()
    slideTrackerMovedRef.current = false
  }, [])

  return useMemo(
    () => ({
      onMouseTouchUp,
      onMouseTouchMove,
      onMouseTouchDown,
      onItemClick,
      userPressing,
      slideTrackerProps: {
        ref: sliderTrackRef,
        onMouseUp: onMouseTouchUp,
        onMouseDown: onMouseTouchDown,
        onMouseMove: onMouseTouchMove,
        // If the button is released outside the track, `mouseup` never reaches
        // it and the drag would stay active. Finish the drag on leave instead.
        // This is a no-op when not pressing.
        onMouseLeave: onMouseTouchUp,
        onTouchStart: onMouseTouchDown,
        onTouchMove: onMouseTouchMove,
        onTouchEnd: onMouseTouchUp,
      },
    }),
    [
      onItemClick,
      onMouseTouchDown,
      onMouseTouchMove,
      onMouseTouchUp,
      userPressing,
    ],
  )
}

type GetNewActiveSlideArgs = {
  currentSlide: number
  accumulatedOffsets: number[]
  translateAxisRef: React.MutableRefObject<number>
  initialTranslateAxisRef: React.MutableRefObject<number>
}

/**
 * Picks the slide the track should snap to when a drag ends.
 *
 * `accumulatedOffsets[i]` is expected to be the distance from the start of the
 * track to the start of slide `i`, followed by one trailing entry for the end
 * of the track (the shape produced by `calculateAccumulatedOffset`).
 *
 * - Drags shorter than `TRANSLATE_THRESHOLD` px keep `currentSlide`.
 * - Dragging past the start of the track (positive translate) snaps to 0.
 * - Otherwise, let `k` be the slide containing the current position. A forward
 *   drag snaps to `k + 1` and a backward drag snaps to `k`. Either result is
 *   clamped to the last slide.
 *
 * @returns The index of the slide to activate.
 */
function getNewActiveSlide(args: GetNewActiveSlideArgs) {
  const {
    currentSlide,
    initialTranslateAxisRef,
    accumulatedOffsets,
    translateAxisRef,
  } = args

  const xDiff = getTranslateXDiff({
    translateAxisRef,
    initialTranslateAxisRef,
  })

  if (xDiff < TRANSLATE_THRESHOLD) {
    return currentSlide
  }

  if (translateAxisRef.current > 0) {
    return 0
  }

  // One entry per slide start plus a trailing entry for the end of the track.
  const lastSlide = accumulatedOffsets.length - 2

  if (lastSlide < 0) {
    // Nothing was measured; there is no slide to move to.
    return currentSlide
  }

  const position = Math.abs(translateAxisRef.current)
  const movingForward =
    translateAxisRef.current < initialTranslateAxisRef.current

  // Find the first slide starting beyond the current position. The position
  // therefore lies inside the slide before it (i - 1).
  for (let i = 1; i <= lastSlide; i++) {
    if (accumulatedOffsets[i] > position) {
      return movingForward ? i : i - 1
    }
  }

  // The position is at or past the start of the last slide. Previously this
  // fell through to `currentSlide`, so a long forward drag could stay put while
  // a shorter one advanced. Clamp to the last slide instead.
  return lastSlide
}

type ResetAndTranslateArgs = {
  newActiveSlide: number
  accumulatedOffsets: number[]
  pressingRef: React.MutableRefObject<boolean>
  sliderTrackRef: React.RefObject<HTMLElement>
  translateAxisRef: React.MutableRefObject<number>
  initialClientXRef: React.MutableRefObject<number | undefined>
  initialTranslateAxisRef: React.MutableRefObject<number>
}

/**
 * Ends a drag: clears the gesture state and moves the track to the start of
 * `newActiveSlide`.
 *
 * This function:
 * - sets `pressingRef` to `false`;
 * - clears `initialClientXRef`;
 * - rewinds `translateAxisRef` to `initialTranslateAxisRef`;
 * - adds `TRANSITION_CLASSES` to the track;
 * - sets the track's transform to `-accumulatedOffsets[newActiveSlide]` px.
 *
 * When the track ref is empty nothing is changed; in development an error is
 * logged.
 */
function resetAndTranslate({
  pressingRef,
  initialClientXRef,
  translateAxisRef,
  newActiveSlide,
  sliderTrackRef,
  initialTranslateAxisRef,
  accumulatedOffsets,
}: ResetAndTranslateArgs) {
  const track = sliderTrackRef.current

  if (!track) {
    if (process.env.NODE_ENV === 'development') {
      console.error('No ref defined in the `sliderTrackRef`')
    }

    return
  }

  // Back to the idle gesture state.
  translateAxisRef.current = initialTranslateAxisRef.current
  initialClientXRef.current = undefined
  pressingRef.current = false

  const targetOffset = accumulatedOffsets[newActiveSlide]

  track.classList.add(TRANSITION_CLASSES)
  track.style.transform = `translate3d(-${targetOffset}px, 0, 0)`
}

type GetTranslateXDiffArgs = {
  translateAxisRef: React.MutableRefObject<number>
  initialTranslateAxisRef: React.MutableRefObject<number>
}

/**
 * Distance, in px, the track has been dragged since the press started.
 *
 * Computed as the absolute difference of the two translate values. Comparing
 * their magnitudes instead (the previous approach) under-reports drags that
 * cross 0: a drag from -30 to +30 would read as 0 instead of 60.
 */
function getTranslateXDiff(args: GetTranslateXDiffArgs) {
  const { initialTranslateAxisRef, translateAxisRef } = args

  return Math.abs(translateAxisRef.current - initialTranslateAxisRef.current)
}

/**
 * Builds the running sum of the track children's `offsetWidth`s.
 *
 * The result starts with 0 and has one more entry than the track has children:
 * entry `i` is the summed width of the first `i` children.
 */
function calculateAccumulatedOffset(el: HTMLElement): number[] {
  const offsets = [0]
  let runningTotal = 0

  for (const slide of Array.from(el.children) as HTMLElement[]) {
    runningTotal += slide.offsetWidth
    offsets.push(runningTotal)
  }

  return offsets
}

type SetUpToMoveArgs = {
  clientX: number
  transitionClasses: string
  setUserPressing: (pressing: boolean) => void
  sliderTrackRef: React.RefObject<HTMLElement>
  pressingRef: React.MutableRefObject<boolean>
  translateAxisRef: React.MutableRefObject<number>
  initialTranslateAxisRef: React.MutableRefObject<number>
  initialClientXRef: React.MutableRefObject<number | undefined>
}

/**
 * Starts a drag gesture.
 *
 * It records the pointer's starting X and the track's current translate (read
 * from its inline `transform`), and marks the slider as pressed. It also
 * removes `transitionClasses` so the track follows the pointer without
 * animating. Does nothing when the track ref is empty.
 */
function setUpToMove(args: SetUpToMoveArgs) {
  const {
    clientX,
    pressingRef,
    initialClientXRef,
    translateAxisRef,
    sliderTrackRef,
    setUserPressing,
    initialTranslateAxisRef,
    transitionClasses,
  } = args

  if (!sliderTrackRef.current) {
    return
  }

  const el = sliderTrackRef.current

  const { x } = getTranslateCoords(el.style.transform)

  setUserPressing(true)
  pressingRef.current = true
  initialClientXRef.current = clientX
  initialTranslateAxisRef.current = x
  translateAxisRef.current = initialTranslateAxisRef.current

  // Split on any whitespace and drop empty tokens. `classList.remove('')`
  // throws, and `''.split(' ')` or doubled spaces would produce one.
  el.classList.remove(...transitionClasses.split(/\s+/).filter(Boolean))
}

/**
 * Gets the translate coordinates from the `transform` CSS property, e.g.
 * `translate3d(-120px, 0, 0)` → `{ x: -120, y: 0, z: 0 }`.
 *
 * Only the first `(...)` group is read. Any component that is missing or not
 * a finite number becomes 0. A value without `(` (such as `none`) yields all
 * zeros instead of throwing. Fractional pixel values are kept (`parseFloat`),
 * so subpixel positions from touch input are not truncated.
 */
function getTranslateCoords(transform: string | undefined) {
  if (!transform) {
    return { x: 0, y: 0, z: 0 }
  }

  const inner = transform.split('(')[1]

  if (inner === undefined) {
    return { x: 0, y: 0, z: 0 }
  }

  const [x, y, z] = (inner.split(')')[0] ?? '').split(',').map((part) => {
    const value = parseFloat(part)

    return Number.isFinite(value) ? value : 0
  })

  return {
    x: x ?? 0,
    y: y ?? 0,
    z: z ?? 0,
  }
}

/**
 * Returns the horizontal viewport coordinate of a mouse or touch event.
 *
 * Mouse events carry `clientX` directly. For touch events, the first entry of
 * `touches` is used; that list must be non-empty.
 *
 * @see https://developer.mozilla.org/en-US/docs/Web/API/Touch/clientX
 * @see https://developer.mozilla.org/en-US/docs/Web/API/MouseEvent/clientX
 */
function getClientX(e: React.MouseEvent | React.TouchEvent) {
  const point = 'clientX' in e ? e : e.touches[0]

  return point.clientX
}
