Mjolnir Core
Core functionality of the Mjolnir API
element_summation.h
Go to the documentation of this file.
1 
7 
8 #pragma once
9 
11 
12 namespace mjolnir::x86
13 {
16 
17 
/// @brief Calculate the sum of all elements of `src`, broadcast it into a new register and return the result.
///
/// @tparam T_RegisterType Floating-point x86 vector register type (single or double precision; the implementation
///                        has an extra cross-lane step for AVX registers).
/// @param src Register whose elements are summed.
/// @return Register in which every element holds the sum of all elements of `src`.
///
/// @note The elements are added pairwise in a fixed permute/add tree, not left to right. The result can therefore
///       differ in the last bits from a sequential scalar sum.
template <FloatVectorRegister T_RegisterType>
[[nodiscard]] inline auto broadcast_element_sum(T_RegisterType src) noexcept -> T_RegisterType;
31 
32 
/// @brief Return the sum of all elements from `src`.
///
/// @tparam T_RegisterType Floating-point x86 vector register type.
/// @param src Register whose elements are summed.
/// @return Scalar sum of all elements. It is computed with the same addition order as `broadcast_element_sum`.
template <FloatVectorRegister T_RegisterType>
[[nodiscard]] inline auto element_sum(T_RegisterType src) noexcept -> ElementType<T_RegisterType>;
46 
47 
/// @brief Return the sum of the first `t_num_elements` elements from `src`.
///
/// @tparam t_num_elements Number of leading elements to sum. It must be in `[1, num_elements<T_RegisterType>]`;
///                        this is enforced with `static_assert`.
/// @tparam T_RegisterType Floating-point x86 vector register type.
/// @param src Register whose leading elements are summed.
/// @return Scalar sum of `src[0]` .. `src[t_num_elements - 1]`.
template <UST t_num_elements, FloatVectorRegister T_RegisterType>
[[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>;
63 
64 
65 // ---internal declarations -------------------------------------------------------------------------------------------
66 
namespace internal
{
/// @brief Double-precision implementation of `mjolnir::x86::element_sum_first_n`.
///
/// @tparam t_num_elements Number of leading elements to sum. The public `element_sum_first_n` validates this value
///                        with `static_assert`. Direct callers of this overload are expected to pass a valid count.
/// @tparam T_RegisterType Double-precision x86 vector register type.
/// @param src Register whose leading elements are summed.
/// @return Sum of the first `t_num_elements` elements of `src`.
template <UST t_num_elements, DoublePrecisionVectorRegister T_RegisterType>
[[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>;


/// @brief Single-precision implementation of `mjolnir::x86::element_sum_first_n`.
///
/// @tparam t_num_elements Number of leading elements to sum. The public `element_sum_first_n` validates this value
///                        with `static_assert`. Direct callers of this overload are expected to pass a valid count.
/// @tparam T_RegisterType Single-precision x86 vector register type.
/// @param src Register whose leading elements are summed.
/// @return Sum of the first `t_num_elements` elements of `src`.
template <UST t_num_elements, SinglePrecisionVectorRegister T_RegisterType>
[[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>;
} // namespace internal
78 
79 
81 } // namespace mjolnir::x86
82 
83 
84 // === DEFINITIONS ====================================================================================================
85 
86 
89 #include "mjolnir/core/x86/x86.h"
90 
91 namespace mjolnir::x86
92 {
93 // --------------------------------------------------------------------------------------------------------------------
94 
95 
96 template <FloatVectorRegister T_RegisterType>
97 [[nodiscard]] inline auto broadcast_element_sum(T_RegisterType src) noexcept -> T_RegisterType
98 {
99  if constexpr (is_single_precision<T_RegisterType>)
100  {
101  T_RegisterType sum = mm_add(src, permute<1, 0, 3, 2>(src));
102  sum = mm_add(sum, permute<2, 3, 0, 1>(sum));
103 
104  if constexpr (is_avx_register<T_RegisterType>)
105  sum = mm_add(sum, swap_lanes(sum));
106 
107  return sum;
108  }
109  else
110  {
111  T_RegisterType sum = mm_add(src, permute<1, 0>(src));
112 
113  if constexpr (is_avx_register<T_RegisterType>)
114  sum = mm_add(sum, swap_lanes(sum));
115 
116  return sum;
117  }
118 }
119 
120 
121 // --------------------------------------------------------------------------------------------------------------------
122 
123 template <FloatVectorRegister T_RegisterType>
124 [[nodiscard]] inline auto element_sum(T_RegisterType src) noexcept -> ElementType<T_RegisterType>
125 {
126  T_RegisterType sum = broadcast_element_sum(src);
127  return mm_cvt_float(sum);
128 }
129 
130 
131 // --------------------------------------------------------------------------------------------------------------------
132 
133 template <UST t_num_elements, FloatVectorRegister T_RegisterType>
134 [[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>
135 {
136  constexpr UST n_e = num_elements<T_RegisterType>;
137 
138  static_assert(t_num_elements > 0, "`t_num_elements` must be larger than 0.");
139  static_assert(t_num_elements <= n_e, "`t_num_elements` must be less or equal to the number of register elements.");
140 
141  return internal::element_sum_first_n<t_num_elements, T_RegisterType>(src);
142 }
143 
144 
145 // --- internal definitions -------------------------------------------------------------------------------------------
146 
148 namespace internal
149 {
150 // --------------------------------------------------------------------------------------------------------------------
151 
152 template <UST t_num_elements, DoublePrecisionVectorRegister T_RegisterType>
153 [[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>
154 {
155  if constexpr (t_num_elements == 1)
156  return mm_cvt_float(src);
157  else if constexpr (t_num_elements == num_elements<T_RegisterType>)
158  return element_sum(src);
159  else
160  {
161  T_RegisterType sum = mm_add(src, permute<1, 0>(src));
162 
163  if constexpr (t_num_elements == 3)
164  sum = mm_add(sum, swap_lanes(src));
165 
166  return mm_cvt_float(sum);
167  }
168 }
169 
170 
171 // --------------------------------------------------------------------------------------------------------------------
172 
173 template <UST t_num_elements, SinglePrecisionVectorRegister T_RegisterType>
174 [[nodiscard]] inline auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType<T_RegisterType>
175 {
176  if constexpr (t_num_elements == 1)
177  return mm_cvt_float(src);
178  else if constexpr (t_num_elements == num_elements<T_RegisterType>)
179  return element_sum(src);
180  else if constexpr (t_num_elements == 7) // NOLINT(readability-magic-numbers)
181  {
182  auto zero = mm_setzero<__m256>();
183  return element_sum(blend_at<t_num_elements>(src, zero));
184  }
185  else
186  {
187  T_RegisterType sum = mm_add(src, permute<1, 0, 3, 2>(src));
188 
189  if constexpr (t_num_elements == 3)
190  sum = mm_add(sum, broadcast<2>(src));
191 
192  if constexpr (t_num_elements == 4 || t_num_elements == 5) // NOLINT(readability-magic-numbers)
193  {
194  sum = mm_add(sum, permute<2, 3, 0, 1>(sum));
195 
196  if constexpr (t_num_elements == 5) // NOLINT(readability-magic-numbers)
197  sum = mm_add(sum, swap_lanes(src));
198  }
199 
200  if constexpr (t_num_elements == 6) // NOLINT(readability-magic-numbers, readability-misleading-indentation)
201  {
202  __m256 tmp = sum;
203  sum = mm_add(sum, permute<2, 3, 0, 1>(sum));
204  sum = mm_add(sum, swap_lanes(tmp));
205  }
206 
207  return mm_cvt_float(sum); // NOLINT(readability-misleading-indentation)
208  }
209 }
211 } // namespace internal
212 
213 
214 } // namespace mjolnir::x86
std::size_t UST
Unsigned integer type that is returned by sizeof operations.
Definition: fundamental_types.h:29
auto mm_cvt_float(T_RegisterType src) -> ElementType< T_RegisterType >
Return the first element of src.
Definition: intrinsics.h:739
concept SinglePrecisionVectorRegister
Concept for a x86 vector register that has single precision elements.
Definition: definitions.h:66
typename std::conditional_t< is_any_of< T_RegisterType, __m128d, __m256d >(), F64, F32 > ElementType
The element type of an x86 vector register that is based on floating-point types.
Definition: definitions.h:212
auto broadcast(T_RegisterType src) noexcept -> T_RegisterType
Broadcast a register element per lane selected by t_index.
Definition: permutation.h:562
auto broadcast_element_sum(T_RegisterType src) noexcept -> T_RegisterType
Calculate the sum of all elements of src, broadcast it into a new register and return the result.
Definition: element_summation.h:97
auto element_sum_first_n(T_RegisterType src) noexcept -> ElementType< T_RegisterType >
Return the sum of the first t_num_elements elements from src.
Definition: element_summation.h:134
constexpr UST num_elements
Number of register elements.
Definition: definitions.h:257
auto mm_add(T_RegisterType lhs, T_RegisterType rhs) noexcept -> T_RegisterType
Perform an element-wise addition of lhs and rhs and return the result.
Definition: intrinsics.h:544
auto element_sum(T_RegisterType src) noexcept -> ElementType< T_RegisterType >
Return the sum of all elements from src.
Definition: element_summation.h:124
auto swap_lanes(T_RegisterType src) noexcept -> T_RegisterType
Swap the lanes of an AVX register and return the result.
Definition: permutation.h:911
Contains generalized/template versions of the x86 intrinsics.
Contains functions to permute and blend the elements of vector registers.
Contains x86 vectorization specific constants, concepts and definitions.
This header includes the correct x86 header depending on the operating system.