#ifndef ESBMC_VARIANT
#define ESBMC_VARIANT

#include "cstddef"
#include "type_traits"
#include "utility"

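// Minimal C++17 <variant> operational model.  Supports up to 4 alternative
// types via a flat discriminator + per-alternative member, which is wider
// than a real tagged union but observably equivalent for verification and
// avoids the constexpr/APValue::Union path that ESBMC's converter does
// not yet handle.  Per [variant] in N4861/P0088.  Provided: default
// construction, construction/assignment from a value whose decayed type is
// exactly one of the alternatives, index(), valueless_by_exception(),
// holds_alternative(), get<I>/get<T> (no check of the active alternative,
// no bad_variant_access) and get_if<T>.  Out of scope (not provided by this
// header): visit(), monostate, in_place_type/in_place_index construction,
// emplace(), comparisons and swap().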

namespace std
{

// Sentinel discriminator value (the largest size_t).  valueless_by_exception()
// reports true exactly when index() equals it.
inline constexpr size_t variant_npos = ~static_cast<size_t>(0);

namespace __variant_detail
{

// __nth<I, Ts...>::type names the I-th (zero-based) type of the pack Ts.
// An index past the end reaches the declared-but-undefined primary
// template and fails to compile.
template <size_t I, class... Ts>
struct __nth;
template <class Head, class... Tail>
struct __nth<0, Head, Tail...>
{
  using type = Head;
};
template <size_t I, class Head, class... Tail>
struct __nth<I, Head, Tail...>
{
  // Drop the head and look one position earlier in the tail.
  using type = typename __nth<I - 1, Tail...>::type;
};

// __index_of<T, Ts...>::value is the zero-based position of the first
// occurrence of T in Ts.  When T does not occur, value equals
// sizeof...(Ts), i.e. one past the last valid index.
template <class T, class... Ts>
struct __index_of;
template <class T>
struct __index_of<T>
{
  // Empty pack: nothing left to skip.
  static constexpr size_t value = 0;
};
template <class T, class Head, class... Tail>
struct __index_of<T, Head, Tail...>
{
  static constexpr size_t value =
    !is_same<T, Head>::value ? 1 + __index_of<T, Tail...>::value : 0;
};

// __at_or_char<I, Ts...>::type is the I-th type of Ts for I in 0..3, and
// `char` when Ts has no I-th element.  It lets the variant body name four
// storage slots unconditionally.  The `char` placeholders carry no meaning;
// they stay unobservable as long as users of this helper only select slots
// whose index is below sizeof...(Ts).
template <size_t I, class... Ts>
struct __at_or_char
{
  using type = char;
};
template <class A0, class... Rest>
struct __at_or_char<0, A0, Rest...>
{
  using type = A0;
};
template <class A0, class A1, class... Rest>
struct __at_or_char<1, A0, A1, Rest...>
{
  using type = A1;
};
template <class A0, class A1, class A2, class... Rest>
struct __at_or_char<2, A0, A1, A2, Rest...>
{
  using type = A2;
};
template <class A0, class A1, class A2, class A3, class... Rest>
struct __at_or_char<3, A0, A1, A2, A3, Rest...>
{
  using type = A3;
};

} // namespace __variant_detail

template <class... Types>
class variant
{
public:
  // Storage.  Discriminator + four typed slots.  Wasteful at runtime but
  // observably equivalent for verification.
  size_t __idx_;
  typename __variant_detail::__at_or_char<0, Types...>::type __s0_;
  typename __variant_detail::__at_or_char<1, Types...>::type __s1_;
  typename __variant_detail::__at_or_char<2, Types...>::type __s2_;
  typename __variant_detail::__at_or_char<3, Types...>::type __s3_;

  static_assert(
    sizeof...(Types) <= 4,
    "this <variant> OM only supports up to 4 alternatives");

  constexpr variant() : __idx_(0), __s0_(), __s1_(), __s2_(), __s3_()
  {
  }

  template <class T>
  constexpr variant(T &&v) : __idx_(0), __s0_(), __s1_(), __s2_(), __s3_()
  {
    __assign(std::forward<T>(v));
  }

  template <class T>
  variant &operator=(T &&v)
  {
    __assign(std::forward<T>(v));
    return *this;
  }

  constexpr size_t index() const noexcept
  {
    return __idx_;
  }
  constexpr bool valueless_by_exception() const noexcept
  {
    return __idx_ == variant_npos;
  }

private:
  // Pick the slot for the alternative whose decayed type matches T.
  // We could use the __index_of helper, but we explicitly hand-unroll
  // for slots 0..3 to avoid relying on constexpr-if at the OM level.
  template <class T>
  void __assign(T &&v)
  {
    using D = typename decay<T>::type;
    using S0 = typename __variant_detail::__at_or_char<0, Types...>::type;
    using S1 = typename __variant_detail::__at_or_char<1, Types...>::type;
    using S2 = typename __variant_detail::__at_or_char<2, Types...>::type;
    using S3 = typename __variant_detail::__at_or_char<3, Types...>::type;

    if constexpr (is_same<D, S0>::value)
    {
      __idx_ = 0;
      __s0_ = std::forward<T>(v);
    }
    else if constexpr (is_same<D, S1>::value)
    {
      __idx_ = 1;
      __s1_ = std::forward<T>(v);
    }
    else if constexpr (is_same<D, S2>::value)
    {
      __idx_ = 2;
      __s2_ = std::forward<T>(v);
    }
    else if constexpr (is_same<D, S3>::value)
    {
      __idx_ = 3;
      __s3_ = std::forward<T>(v);
    }
  }
};

// True when the active alternative of `v` sits at the (first) position of T
// in Types.  If T is not among Types, its computed position is
// sizeof...(Types), which index() never reports, so the result is false.
template <class T, class... Types>
constexpr bool holds_alternative(const variant<Types...> &v) noexcept
{
  constexpr auto wanted = __variant_detail::__index_of<T, Types...>::value;
  return wanted == v.index();
}

// Mutable reference to slot I of `v`.  I must be below sizeof...(Types);
// otherwise the return type does not exist and the call fails to compile.
// The active alternative is not checked (no bad_variant_access).
template <size_t I, class... Types>
constexpr typename __variant_detail::__nth<I, Types...>::type &
get(variant<Types...> &v)
{
  if constexpr (I < 2)
  {
    if constexpr (I == 0)
      return v.__s0_;
    else
      return v.__s1_;
  }
  else if constexpr (I == 2)
    return v.__s2_;
  else
    return v.__s3_;
}

// Const reference to slot I of `v`.  I must be below sizeof...(Types);
// otherwise the return type does not exist and the call fails to compile.
// The active alternative is not checked (no bad_variant_access).
template <size_t I, class... Types>
constexpr const typename __variant_detail::__nth<I, Types...>::type &
get(const variant<Types...> &v)
{
  if constexpr (I < 2)
  {
    if constexpr (I == 0)
      return v.__s0_;
    else
      return v.__s1_;
  }
  else if constexpr (I == 2)
    return v.__s2_;
  else
    return v.__s3_;
}

// Mutable reference to the alternative of type T.  Forwards to get<I>()
// with I the first position of T in Types; if T is not an alternative,
// that position is out of range and the call fails to compile.
template <class T, class... Types>
constexpr T &get(variant<Types...> &v)
{
  constexpr size_t slot = __variant_detail::__index_of<T, Types...>::value;
  return get<slot>(v);
}

template <class T, class... Types>
constexpr const T &get(const variant<Types...> &v)
{
  return get<__variant_detail::__index_of<T, Types...>::value>(v);
}

// Pointer to the alternative of type T when `v` is non-null and that
// alternative is active; nullptr otherwise.
template <class T, class... Types>
constexpr typename add_pointer<T>::type
get_if(variant<Types...> *v) noexcept
{
  constexpr size_t I = __variant_detail::__index_of<T, Types...>::value;
  if (v != nullptr && v->index() == I)
  {
    // Only the branch for slot I is instantiated.
    if constexpr (I == 0)
      return &v->__s0_;
    else if constexpr (I == 1)
      return &v->__s1_;
    else if constexpr (I == 2)
      return &v->__s2_;
    else
      return &v->__s3_;
  }
  return nullptr;
}

// Const pointer to the alternative of type T when `v` is non-null and that
// alternative is active; nullptr otherwise.
template <class T, class... Types>
constexpr typename add_pointer<const T>::type
get_if(const variant<Types...> *v) noexcept
{
  constexpr size_t I = __variant_detail::__index_of<T, Types...>::value;
  if (v != nullptr && v->index() == I)
  {
    // Only the branch for slot I is instantiated.
    if constexpr (I == 0)
      return &v->__s0_;
    else if constexpr (I == 1)
      return &v->__s1_;
    else if constexpr (I == 2)
      return &v->__s2_;
    else
      return &v->__s3_;
  }
  return nullptr;
}

} // namespace std

#endif
