Skip to content

Commit

Permalink
Fix #2125.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Mar 23, 2024
1 parent 30ce414 commit 36acf29
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

* Incremental Flattening generates fewer redundant code versions.

* Better simplification of slices. (#2125)

### Removed

### Changed
Expand Down
25 changes: 25 additions & 0 deletions src/Futhark/Optimise/Simplify/Rules/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Futhark.Optimise.Simplify.Rules.Index
)
where

import Control.Monad (guard)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
Expand All @@ -31,6 +32,14 @@ data IndexResult
= IndexResult Certs VName (Slice SubExp)
| SubExpResult Certs SubExp

-- Fake expressions that we can recognise.
fakeIndices :: [TPrimExp Int64 VName]
fakeIndices = map f [0 :: Int ..]
where
f i = isInt64 $ LeafExp (VName v (negate i)) $ IntType Int64
where
v = nameFromText ("fake_" <> showText i)

-- | Try to simplify an index operation.
simplifyIndexing ::
(MonadBuilder m) =>
Expand Down Expand Up @@ -60,6 +69,22 @@ simplifyIndexing vtable seType idd (Slice inds) consuming consumed =
Just $
IndexResult cs arr . Slice . map DimFix
<$> mapM (toSubExp "index_primexp") inds''
| Just (ST.IndexedArray cs arr inds'') <-
ST.index' idd (fixSlice (pe64 <$> Slice inds) (map fst matches)) vtable,
all (worthInlining . untyped) inds'',
arr `ST.available` vtable,
all (`ST.elem` vtable) (unCerts cs),
Just inds''' <- mapM okIdx inds'' -> do
Just $ IndexResult cs arr . Slice <$> sequence inds'''
where
matches = zip fakeIndices $ sliceDims $ Slice inds
okIdx i =
case lookup i matches of
Just w ->
Just $ pure $ DimSlice (constant (0 :: Int64)) w (constant (1 :: Int64))
Nothing -> do
guard $ not $ any ((`namesIntersect` freeIn i) . freeIn . fst) matches
Just $ DimFix <$> toSubExp "index_primexp" i
Nothing -> Nothing
Just (SubExp (Var v), cs) ->
Just $ pure $ IndexResult cs v $ Slice inds
Expand Down
26 changes: 26 additions & 0 deletions tests/issue2125.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Two things are necessary here:
--
-- 1) not generating unnecessary versions.
--
-- 2) simplifying away the slices.
--
-- ==
-- structure gpu { /If/True/SegMap 1 /If/False/SegRed 1 }

entry example_tc5
[A][B][I][J]
[Q]
(xsss: [Q][A][I]f32)
(ysss: [B][Q][J]f32)
: [I][B][J][A]f32 =

#[unsafe]
map (\i -> -- dim 0
map (\b -> -- dim 1
map (\j -> -- dim 2
map (\a -> -- dim 3
map2 (*) xsss[:, a, i] ysss[b, :, j] |> f32.sum
) (iota A)
) (iota J)
) (iota B)
) (iota I)

0 comments on commit 36acf29

Please sign in to comment.