Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 111 additions & 39 deletions include/ck/utility/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,55 +199,113 @@ template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;

// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
// merge sequence - optimized to reduce recursive instantiation depth
//
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why:
//
// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each
// element can be computed independently: output[i] = f(i)
//
// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths.
// To compute output[i], we need to know:
// 1. Which input sequence contains this index
// 2. The offset within that sequence
// This requires computing cumulative sequence lengths, which requires recursion/iteration.
//
// Instead, we shorten the recursion with direct specializations:
// - Base cases handle 1-4 sequences directly (O(1) for common cases)
// - Recursive case (5+ sequences) merges the first pair, then the rest, then combines:
//   merge(merge(s1,s2), merge(s3,s4,...))
// - Each recursive step removes only two sequences from the pack, so the instantiation
//   depth is about N/2. That is still O(N), not O(log N); a true O(log N) scheme would
//   have to split the pack in half at every level.
//
// Alternative considered: Fold expressions (... + sequences) would also give O(N) depth due to
// linear dependency chain, so they were not adopted.
//
namespace detail {

// Primary template of the merge implementation; declared only, never defined.
// Concatenation is provided by the partial specializations of this template, which are
// selected by the number of sequences passed (no fold expression is involved).
template <typename... Seqs>
struct sequence_merge_impl;

// Base case: single sequence
template <index_t... Is>
struct sequence_merge_impl<Sequence<Is...>>
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
using type = Sequence<Is...>;
};

// Two sequences: direct concatenation
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};

template <typename Seq>
struct sequence_merge<Seq>
// Three sequences: direct concatenation (avoids one level of recursion)
template <index_t... Xs, index_t... Ys, index_t... Zs>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
{
using type = Seq;
using type = Sequence<Xs..., Ys..., Zs...>;
};

// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
// Four sequences: direct concatenation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like these specializations. It will be interesting to get a survey of the code to see how often the specializations are used and if these four smallest cases are the most impactful ones.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using the build traces to drive the optimizations. Removing unused code is another aspect that could help with parsing times.

template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = Sequence<As..., Bs..., Cs..., Ds...>;
};

using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
// General case: concatenate the first two sequences directly, merge everything after
// them (recursively when more than four sequences remain), then join both results.
// In practice this is selected for five or more sequences, because up to four
// Sequence<...> arguments match the more specialized direct-concatenation cases.
// NOTE: each level removes only two sequences from the pack, so the instantiation depth
// grows as roughly N/2 (linear in the number of sequences), not O(log N).
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
{
// Direct two-sequence concatenation of the leading pair
using left = typename sequence_merge_impl<S1, S2>::type;
// Merge of all remaining sequences; this is the recursive step
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
// Join the two partial results with the two-sequence case
using type = typename sequence_merge_impl<left, right>::type;
};

template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
} // namespace detail

template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
// Public entry point for merging: forwards every argument to
// detail::sequence_merge_impl and exposes the concatenated result as `type`.
template <typename... Sequences>
struct sequence_merge
{
    using type = typename detail::sequence_merge_impl<Sequences...>::type;
};

// Merging zero sequences yields the empty Sequence. This case is spelled out here
// instead of being forwarded to detail::sequence_merge_impl.
template <>
struct sequence_merge<>
{
using type = Sequence<>;
};

// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
namespace detail {

// Carrier for an index pack that can turn a functor into a Sequence.
// __make_integer_seq<sequence_gen_helper, index_t, N> yields
// sequence_gen_helper<index_t, 0, 1, ..., N-1>.
template <typename IndexT, IndexT... Indices>
struct sequence_gen_helper
{
    // Invokes a default-constructed Functor once per index, passing Number<index>{}, and
    // gathers all results into a Sequence in one pack expansion (O(1) depth).
    template <typename Functor>
    using apply = Sequence<Functor{}(Number<Indices>{})...>;
};

} // namespace detail

using type = typename sequence_gen_impl<0, NSize, F>::type;
template <index_t NSize, typename F>
struct sequence_gen
{
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
};

// Zero-length generation: the result is the empty Sequence and the functor is unused.
template <typename Functor>
struct sequence_gen<0, Functor>
{
    using type = Sequence<>;
};

// arithmetic sequence
Expand Down Expand Up @@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
};

// uniform sequence
// uniform sequence - optimized using __make_integer_seq
namespace detail {

// Carrier for an index pack; the indices only determine how many elements are produced.
template <typename IndexT, IndexT... Indices>
struct uniform_sequence_helper
{
    // Every index is evaluated and discarded through the comma operator, so each slot of
    // the resulting Sequence holds Value.
    template <index_t Value>
    using apply = Sequence<(static_cast<void>(Indices), Value)...>;
};

} // namespace detail

template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
template apply<I>;
};

using type = typename sequence_gen<NSize, F>::type;
template <index_t I>
struct uniform_sequence_gen<0, I>
{
using type = Sequence<>;
};

// reverse inclusive scan (with init) sequence
Expand Down
1 change: 1 addition & 0 deletions include/ck/utility/statically_indexed_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
using type = Tuple<Xs..., Ys...>;
};

// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{
Expand Down
134 changes: 134 additions & 0 deletions test/util/unit_sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize)
EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceGen, UniformSequenceSingleElement)
{
    // A length-1 uniform sequence holds exactly the requested value.
    using Expected = Sequence<99>;
    using Result   = uniform_sequence_gen<1, 99>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceGen, UniformSequenceDifferentValues)
{
    // Fill value zero.
    using Expected1 = Sequence<0, 0, 0>;
    using Result1   = uniform_sequence_gen<3, 0>::type;
    EXPECT_TRUE((is_same<Result1, Expected1>::value));

    // Negative fill value.
    using Expected2 = Sequence<-5, -5, -5, -5>;
    using Result2   = uniform_sequence_gen<4, -5>::type;
    EXPECT_TRUE((is_same<Result2, Expected2>::value));
}

TEST(SequenceGen, UniformSequenceLargeSize)
{
    // Sixteen elements: checks the __make_integer_seq-based implementation at a larger size.
    using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>;
    using Result   = uniform_sequence_gen<16, 7>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

// Test make_index_sequence
TEST(SequenceGen, MakeIndexSequence)
{
Expand All @@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero)
EXPECT_TRUE((is_same<Result, Expected>::value));
}

// Test sequence_gen with custom functors
TEST(SequenceGen, SequenceGenWithDoubleFunctor)
{
    // Maps each index to twice its value.
    struct TimesTwo
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const { return 2 * idx; }
    };
    using Expected = Sequence<0, 2, 4, 6, 8>;
    using Result   = sequence_gen<5, TimesTwo>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceGen, SequenceGenWithSquareFunctor)
{
    // Maps each index to its square.
    struct Squared
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const { return idx * idx; }
    };
    using Expected = Sequence<0, 1, 4, 9, 16>;
    using Result   = sequence_gen<5, Squared>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceGen, SequenceGenZeroSize)
{
    // Returns the index unchanged.
    struct PassThrough
    {
        __host__ __device__ constexpr index_t operator()(index_t idx) const { return idx; }
    };

    // Size zero must produce the empty sequence.
    using Expected = Sequence<>;
    using Result   = sequence_gen<0, PassThrough>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));

    // Sanity check: the same functor with a non-zero size yields 0..4.
    using Result5 = sequence_gen<5, PassThrough>::type;
    EXPECT_TRUE((is_same<Result5, Sequence<0, 1, 2, 3, 4>>::value));
}

TEST(SequenceGen, SequenceGenSingleElement)
{
    // Ignores its argument and always yields 42.
    struct AlwaysFortyTwo
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return 42; }
    };
    using Expected = Sequence<42>;
    using Result   = sequence_gen<1, AlwaysFortyTwo>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

// Test sequence_merge
TEST(SequenceMerge, MergeTwoSequences)
{
Expand Down Expand Up @@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence)
EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceMerge, MergeFourSequences)
{
    // Four inputs of different lengths: exercises the four-sequence specialization.
    using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>;
    using Result =
        sequence_merge<Sequence<1>, Sequence<2, 3>, Sequence<4, 5, 6>, Sequence<7, 8>>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceMerge, MergeFiveSequences)
{
    // Five inputs: the smallest count that takes the recursive general-case path.
    using Expected = Sequence<1, 2, 3, 4, 5>;
    using Result =
        sequence_merge<Sequence<1>, Sequence<2>, Sequence<3>, Sequence<4>, Sequence<5>>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceMerge, MergeManySequences)
{
    // Eight inputs of mixed lengths: goes through several levels of the general case.
    using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>;
    using Result   = sequence_merge<Sequence<1>,
                                  Sequence<2>,
                                  Sequence<3, 4>,
                                  Sequence<5>,
                                  Sequence<6, 7>,
                                  Sequence<8>,
                                  Sequence<9, 10>,
                                  Sequence<11, 12>>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceMerge, MergeEmptySequences)
{
    // Empty inputs on either side contribute nothing to the merged result.
    using Expected = Sequence<1, 2>;
    using Result   = sequence_merge<Sequence<>, Sequence<1, 2>, Sequence<>>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

TEST(SequenceMerge, MergeZeroSequences)
{
    // No arguments at all: handled by the explicit sequence_merge<> specialization.
    using Expected = Sequence<>;
    using Result   = sequence_merge<>::type;
    EXPECT_TRUE((is_same<Result, Expected>::value));
}

// Test sequence_split
TEST(SequenceSplit, SplitInMiddle)
{
Expand Down
Loading