3 Коммиты 052debf397 ... ac6b110d66

Автор SHA1 Сообщение Дата
  lloda ac6b110d66 Give up on RA_STATIC_UNROLL 5 месяцев назад
  lloda 7e7b53138b Omit rt match checks when only one arg has positive rank 5 месяцев назад
  lloda df4173c699 Guard against uses of ra::len in scalar conversion ops 5 месяцев назад
7 измененных файлов с 68 добавлено и 96 удалено
  1. 2 4
      bench/SConstruct
  2. 3 6
      config/ra.py
  3. 45 39
      ra/expr.hh
  4. 6 43
      ra/ply.hh
  5. 1 4
      test/SConstruct
  6. 5 0
      test/checks.cc
  7. 6 0
      test/len.cc

+ 2 - 4
bench/SConstruct

@@ -68,11 +68,9 @@ tester = ra.to_test_ra(env, variant_dir)
                'bench-pack', 'bench-from',
                'bench-stencil1', 'bench-stencil2', 'bench-stencil3',
                'bench-optimize', 'bench-tensorindex',
-               'bench-iterator', 'bench-at'
+               'bench-iterator', 'bench-at',
+               'bench-dot'
            ]]
 
-tester('bench-dot', target='bench-dot-no-su', cppdefines={'RA_STATIC_UNROLL': '0'})
-tester('bench-dot', target='bench-dot-su', cppdefines={'RA_STATIC_UNROLL': '1'})
-
 if not top['skip_summary']:
     atexit.register(lambda: ra.print_summary(GetBuildFailures, 'ra/bench'))

+ 3 - 6
config/ra.py

@@ -130,12 +130,9 @@ def to_source_from_noweb(env, targets, source):
     return [env.Notangle(target, remove_ext(main) + '.nw') for target in targets]
 
 def to_test_ra(env_, variant_dir):
-    def f(source, target='', cxxflags=[], cppdefines=[]):
-        if len(cxxflags)==0 or len(cppdefines)==0:
-            env = env_
-        else:
-            env = env_.Clone()
-            env.Append(CXXFLAGS=cxxflags + ['-U' + k for k in cppdefines.keys()], CPPDEFINES=cppdefines)
+    def f(source, target='', cxxflags=[], cppdefines={}):
+        env = env_.Clone()
+        env.Append(CXXFLAGS=cxxflags + ['-U' + k for k in cppdefines.keys()], CPPDEFINES=cppdefines)
         if len(target)==0:
             target = source
         obj = env.Object(target, [source + '.cc'])

+ 45 - 39
ra/expr.hh

@@ -75,6 +75,36 @@ constexpr bool inside(dim_t i, dim_t b) { return 0<=i && i<b; }
 // terminal types
 // --------------------
 
+constexpr struct Len
+{
+    consteval static rank_t rank() { return 0; }
+    constexpr static dim_t len_s(int k) { std::abort(); }
+    constexpr static dim_t len(int k) { std::abort(); }
+    constexpr static dim_t step(int k) { std::abort(); }
+    constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
+    constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
+    constexpr dim_t operator*() const { std::abort(); }
+    constexpr static int save() { std::abort(); }
+    constexpr static void load(int) { std::abort(); }
+    constexpr static void mov(dim_t d) { std::abort(); }
+} len;
+
+template <> constexpr bool is_special_def<Len> = true;  // protect exprs with Len from reduction.
+template <class E> struct WLen;                         // defined in ply.hh.
+template <class E> concept has_len = requires(int ln, E && e) { WLen<std::decay_t<E>>::f(ln, RA_FWD(e)); };
+
+template <class Ln, class E>
+constexpr decltype(auto)
+wlen(Ln ln, E && e)
+{
+    static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
+    if constexpr (has_len<E>) {
+        return WLen<std::decay_t<E>>::f(ln, RA_FWD(e));
+    } else {
+        return RA_FWD(e);
+    }
+}
+
 // Rank-0 IteratorConcept. Can be used on foreign objects, or as alternative to the rank conjunction.
 // We still want f(scalar(C)) to be f(C) and not map(f, C), this is controlled by tomap/toreduce.
 template <class C>
@@ -253,25 +283,6 @@ inside(is_iota auto const & i, dim_t l)
     return (inside(i.i, l) && inside(i.i+(i.n-1)*i.s, l)) || (0==i.n /* don't bother */);
 }
 
-constexpr struct Len
-{
-    consteval static rank_t rank() { return 0; }
-    constexpr static dim_t len_s(int k) { std::abort(); }
-    constexpr static dim_t len(int k) { std::abort(); }
-    constexpr static dim_t step(int k) { std::abort(); }
-    constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
-    constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
-    constexpr dim_t operator*() const { std::abort(); }
-    constexpr static int save() { std::abort(); }
-    constexpr static void load(int) { std::abort(); }
-    constexpr static void mov(dim_t d) { std::abort(); }
-} len;
-
-// protect exprs with Len from reduction.
-template <> constexpr bool is_special_def<Len> = true;
-template <class E> struct WLen {};
-template <class E> concept has_len = requires(int ln, E && e) { WLen<std::decay_t<E>>::f(ln, RA_FWD(e)); };
-
 
 // --------------
 // making Iterators
@@ -314,6 +325,17 @@ start(T & t) { return t; }
 constexpr decltype(auto)
 start(is_iterator auto && t) { return RA_FWD(t); }
 
+// a form of ply() for conversion ops
+template <class E>
+decltype(auto) to_scalar(E && e)
+{
+    static_assert(!has_len<E>, "len outside subscript context.");
+    if constexpr (1!=size_s<E>()) {
+        RA_CHECK(1==size(e), "Bad scalar conversion from shape [", ra::noshape, ra::shape(e), "].");
+    }
+    return *e;
+}
+
 
 // --------------------
 // prefix match
@@ -337,10 +359,10 @@ struct Match<checkp, std::tuple<P ...>, mp::int_list<I ...>>
     consteval static int
     check_s()
     {
-        if constexpr (sizeof...(P)<2) {
+        if constexpr (sizeof...(P)<2 || sizeof...(P)==1+(bool(0==ra::rank_s<P>()) + ...)) {
             return 2;
         } else if constexpr (ANY==rs) {
-            return 1; // FIXME can be tightened to 2 if all args are rank 0 save one
+            return 1;
         } else {
             bool tbc = false;
             for (int k=0; k<rs; ++k) {
@@ -661,15 +683,7 @@ struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std:
     RA_ASSIGNOPS_DEFAULT_SET
     constexpr decltype(auto) at(auto const & j) const { return std::invoke(op, std::get<I>(t).at(j) ...); }
     constexpr decltype(auto) operator*() const { return std::invoke(op, *std::get<I>(t) ...); }
-// needed for rs==ANY, which don't decay to scalar when used as operator arguments.
-    constexpr
-    operator decltype(std::invoke(op, *std::get<I>(t) ...)) () const
-    {
-        if constexpr (1!=size_s<Expr>()) {
-            RA_CHECK(1==size(*this), "Bad conversion to scalar from shape [", ra::noshape, ra::shape(*this), "].");
-        }
-        return *(*this);
-    }
+    constexpr operator decltype(std::invoke(op, *std::get<I>(t) ...)) () const { return to_scalar(*this); }
 };
 
 template <class Op, IteratorConcept ... P>
@@ -751,15 +765,7 @@ struct Pick<std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tup
     RA_ASSIGNOPS_DEFAULT_SET
     constexpr decltype(auto) at(auto const & j) const { return pick_at<0>(std::get<0>(t).at(j), t, j); }
     constexpr decltype(auto) operator*() const { return pick_star<0>(*std::get<0>(t), t); }
-// needed for rs==ANY, which don't decay to scalar when used as operator arguments.
-    constexpr
-    operator decltype(pick_star<0>(*std::get<0>(t), t)) () const
-    {
-        if constexpr (1!=size_s<Pick>()) {
-            RA_CHECK(1==size(*this), "Bad conversion to scalar from shape [", ra::noshape, ra::shape(*this), "].");
-        }
-        return *(*this);
-    }
+    constexpr operator decltype(pick_star<0>(*std::get<0>(t), t)) () const { return to_scalar(*this); }
 };
 
 template <IteratorConcept ... P>

+ 6 - 43
ra/ply.hh

@@ -34,18 +34,6 @@ template <class A> using ncvalue_t = std::remove_const_t<value_t<A>>;
 // replace Len in expr tree.
 // ---------------------
 
-template <class Ln, class E>
-constexpr decltype(auto)
-wlen(Ln ln, E && e)
-{
-    static_assert(std::is_integral_v<std::decay_t<Ln>> || is_constant<std::decay_t<Ln>>);
-    if constexpr (has_len<E>) {
-        return WLen<std::decay_t<E>>::f(ln, RA_FWD(e));
-    } else {
-        return RA_FWD(e);
-    }
-}
-
 template <>
 struct WLen<Len>
 {
@@ -220,11 +208,6 @@ subply(A & a, dim_t s, S const & ss0, Early & early)
     }
 }
 
-// possibly pessimize ply_fixed(). See bench-dot [ra43]
-#ifndef RA_STATIC_UNROLL
-#define RA_STATIC_UNROLL 0
-#endif
-
 template <IteratorConcept A, class Early = Nop>
 constexpr decltype(auto)
 ply_fixed(A && a, Early && early = Nop {})
@@ -241,36 +224,16 @@ ply_fixed(A && a, Early && early = Nop {})
             return;
         }
     } else {
-// static keep_step implies all else is static.
-        if constexpr (RA_STATIC_UNROLL && rank>1 && requires (dim_t st, rank_t z, rank_t j) { A::keep_step(st, z, j); }) {
-            constexpr auto ss0 = a.step(order[0]);
-// find outermost compact dim.
-            constexpr auto sj = [&order]
-            {
-                dim_t ss = A::len_s(order[0]);
-                int j = 1;
-                for (; j<rank && A::keep_step(ss, order[0], order[j]); ++j) {
-                    ss *= A::len_s(order[j]);
-                }
-                return std::make_tuple(ss, j);
-            } ();
-            if constexpr (requires {early.def;}) {
-                return (subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early)).value_or(early.def);
-            } else {
-                subply<order, rank-1, std::get<1>(sj)>(a, std::get<0>(sj), ss0, early);
-            }
-        } else {
 #pragma GCC diagnostic push // gcc 12.2 and 13.2 with RA_DO_CHECK=0 and -fno-sanitize=all
 #pragma GCC diagnostic warning "-Warray-bounds"
-            auto ss0 = a.step(order[0]); // gcc 14.1 with RA_DO_CHECK=0 and sanitizer on
+        auto ss0 = a.step(order[0]); // gcc 14.1 with RA_DO_CHECK=0 and sanitizer on
 // not worth unrolling.
-            if constexpr (requires {early.def;}) {
-                return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
-            } else {
-                subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
-            }
-#pragma GCC diagnostic pop
+        if constexpr (requires {early.def;}) {
+            return (subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early)).value_or(early.def);
+        } else {
+            subply<order, rank-1, 1>(a, a.len(order[0]), ss0, early);
         }
+#pragma GCC diagnostic pop
     }
 }
 

+ 1 - 4
test/SConstruct

@@ -53,7 +53,7 @@ tester = ra.to_test_ra(env, variant_dir)
 
 [tester(test)
  for test in ['at', 'bench', 'big-0', 'big-1', 'bug83', 'cellrank', 'checks', 'compatibility',
-              'concrete', 'const', 'constexpr', 'dual', 'explode-0', 'foreign', 'frame-new',
+              'concrete', 'const', 'constexpr', 'dual', 'early', 'explode-0', 'foreign', 'frame-new',
               'frame-old', 'fromb', 'fromu', 'headers', 'io', 'iota', 'iterator-small', 'len',
               'list9', 'macros', 'mem-fn', 'ndebug', 'nested-0', 'old', 'operators', 'optimize',
               'owned', 'ownership', 'ply', 'ra-0', 'ra-1', 'ra-10', 'ra-11', 'ra-12', 'ra-13',
@@ -66,9 +66,6 @@ tester = ra.to_test_ra(env, variant_dir)
               # 'ra-16' # FIXME broken on gcc11, ok on gcc12/13
               ]]
 
-tester('early', target='early-no-su', cppdefines={'RA_STATIC_UNROLL': '0'})
-tester('early', target='early-su', cppdefines={'RA_STATIC_UNROLL': '1'})
-
 tester('ra-10', target='ra-10a', cxxflags=['-O3'], cppdefines={'RA_DO_CHECK': '0'})
 tester('ra-10', target='ra-10b', cxxflags=['-O1'], cppdefines={'RA_DO_CHECK': '0'})
 tester('ra-10', target='ra-10c', cxxflags=['-O3'], cppdefines={'RA_DO_CHECK': '1'})

+ 5 - 0
test/checks.cc

@@ -120,6 +120,11 @@ int main()
 // see test/frame-new.cc
 // ------------------------------
 
+    tr.section("static match in dynamic case");
+    {
+        ra::Big<int> a({2, 3, 4}, 0);
+        tr.test_eq(2, agree_s(a, 99));
+    }
     tr.section("dynamic (implicit) match");
     {
         ra::Big<int, 3> a({2, 3, 4}, (ra::_0+1)*100 + (ra::_1+1)*10 + (ra::_2+1));

+ 6 - 0
test/len.cc

@@ -94,5 +94,11 @@ int main()
         // tr.test_eq(5, wlen(ra::ic<5>, ra::iota(ra::len-ra::ic<1>)).ronk()); // FIXME
         // tr.test_eq(5, wlen(ra::ic<6>, ra::iota(ra::len-ra::ic<1>)).len_s(0)); // FIXME
     }
+    // tr.section("ra::len in ... in x.len(...)");
+    // {
+    //     ra::Big<int, 3> a({2, 3, 4}, 0);
+    //     // int z = ra::len-1; // static assert; use outside subscript context
+    //     cout << a.len(ra::len-1) << endl; // FIXME
+    // }
     return tr.summary();
 }