about summary refs log tree commit diff
path: root/third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs')
-rw-r--r--third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs239
1 files changed, 239 insertions, 0 deletions
diff --git a/third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs b/third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs
new file mode 100644
index 0000000000..8debff23a9
--- /dev/null
+++ b/third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs
@@ -0,0 +1,239 @@
+{-# LANGUAGE ParallelListComp #-}
+module Main where
+
+import Text.PrettyPrint
+
+import System.Environment ( getArgs )
+
+main = do
+         [s] <- getArgs
+         let n = read s
+         mapM_ (putStrLn . render . generate) [2..n]
+
+generate :: Int -> Doc
+generate n =
+  vcat [ text "#ifdef DEFINE_INSTANCES"
+       , data_instance "MVector s" "MV"
+       , data_instance "Vector" "V"
+       , class_instance "Unbox"
+       , class_instance "M.MVector MVector" <+> text "where"
+       , nest 2 $ vcat $ map method methods_MVector
+       , class_instance "G.Vector Vector" <+> text "where"
+       , nest 2 $ vcat $ map method methods_Vector
+       , text "#endif"
+       , text "#ifdef DEFINE_MUTABLE"
+       , define_zip "MVector s" "MV"
+       , define_unzip "MVector s" "MV"
+       , text "#endif"
+       , text "#ifdef DEFINE_IMMUTABLE"
+       , define_zip "Vector" "V"
+       , define_zip_rule
+       , define_unzip "Vector" "V"
+       , text "#endif"
+       ]
+
+  where
+    vars  = map (\c -> text ['_',c]) $ take n ['a'..]
+    varss = map (<> char 's') vars
+    tuple xs = parens $ hsep $ punctuate comma xs
+    vtuple xs = parens $ sep $ punctuate comma xs
+    con s = text s <> char '_' <> int n
+    var c = text ('_' : c : "_")
+
+    data_instance ty c
+      = hang (hsep [text "data instance", text ty, tuple vars])
+             4
+             (hsep [char '=', con c, text "{-# UNPACK #-} !Int"
+                   , vcat $ map (\v -> char '!' <> parens (text ty <+> v)) vars])
+
+    class_instance cls
+      = text "instance" <+> vtuple [text "Unbox" <+> v | v <- vars]
+                        <+> text "=>" <+> text cls <+> tuple vars
+
+
+    define_zip ty c
+      = sep [text "-- | /O(1)/ Zip" <+> int n <+> text "vectors"
+            ,name <+> text "::"
+                  <+> vtuple [text "Unbox" <+> v | v <- vars]
+                  <+> text "=>"
+                  <+> sep (punctuate (text " ->") [text ty <+> v | v <- vars])
+                  <+> text "->"
+                  <+> text ty <+> tuple vars
+             ,text "{-# INLINE_FUSED"  <+> name <+> text "#-}"
+             ,name <+> sep varss
+                   <+> text "="
+                   <+> con c
+                   <+> text "len"
+                   <+> sep [parens $ text "unsafeSlice"
+                                     <+> char '0'
+                                     <+> text "len"
+                                     <+> vs | vs <- varss]
+             ,nest 2 $ hang (text "where")
+                            2
+                     $ text "len ="
+                       <+> sep (punctuate (text " `delayed_min`")
+                                          [text "length" <+> vs | vs <- varss])
+             ]
+      where
+        name | n == 2    = text "zip"
+             | otherwise = text "zip" <> int n
+
+    define_zip_rule
+      = hang (text "{-# RULES" <+> text "\"stream/" <> name "zip"
+              <> text " [Vector.Unboxed]\" forall" <+> sep varss <+> char '.')
+             2 $
+             text "G.stream" <+> parens (name "zip" <+> sep varss)
+             <+> char '='
+             <+> text "Bundle." <> name "zipWith" <+> tuple (replicate n empty)
+             <+> sep [parens $ text "G.stream" <+> vs | vs <- varss]
+             $$ text "#-}"
+     where
+       name s | n == 2    = text s
+              | otherwise = text s <> int n
+       
+
+    define_unzip ty c
+      = sep [text "-- | /O(1)/ Unzip" <+> int n <+> text "vectors"
+            ,name <+> text "::"
+                  <+> vtuple [text "Unbox" <+> v | v <- vars]
+                  <+> text "=>"
+                  <+> text ty <+> tuple vars
+                  <+> text "->" <+> vtuple [text ty <+> v | v <- vars]
+            ,text "{-# INLINE" <+> name <+> text "#-}"
+            ,name <+> pat c <+> text "="
+                  <+> vtuple varss
+            ]
+      where
+        name | n == 2    = text "unzip"
+             | otherwise = text "unzip" <> int n
+
+    pat c = parens $ con c <+> var 'n' <+> sep varss
+    patn c n = parens $ con c <+> (var 'n' <> int n)
+                              <+> sep [v <> int n | v <- varss]
+
+    qM s = text "M." <> text s
+    qG s = text "G." <> text s
+
+    gen_length c _ = (pat c, var 'n')
+
+    gen_unsafeSlice mod c rec
+      = (var 'i' <+> var 'm' <+> pat c,
+         con c <+> var 'm'
+               <+> vcat [parens
+                         $ text mod <> char '.' <> text rec
+                                    <+> var 'i' <+> var 'm' <+> vs
+                                        | vs <- varss])
+
+
+    gen_overlaps rec = (patn "MV" 1 <+> patn "MV" 2,
+                        vcat $ r : [text "||" <+> r | r <- rs])
+      where
+        r : rs = [qM rec <+> v <> char '1' <+> v <> char '2' | v <- varss]
+
+    gen_unsafeNew rec
+      = (var 'n',
+         mk_do [v <+> text "<-" <+> qM rec <+> var 'n' | v <- varss]
+               $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)
+
+    gen_unsafeReplicate rec
+      = (var 'n' <+> tuple vars,
+         mk_do [vs <+> text "<-" <+> qM rec <+> var 'n' <+> v
+                        | v  <- vars | vs <- varss]
+               $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)
+
+    gen_unsafeRead rec
+      = (pat "MV" <+> var 'i',
+         mk_do [v <+> text "<-" <+> qM rec <+> vs <+> var 'i' | v  <- vars
+                                                              | vs <- varss]
+               $ text "return" <+> tuple vars)
+
+    gen_unsafeWrite rec
+      = (pat "MV" <+> var 'i' <+> tuple vars,
+         mk_do [qM rec <+> vs <+> var 'i' <+> v | v  <- vars | vs <- varss]
+               empty)
+
+    gen_clear rec
+      = (pat "MV", mk_do [qM rec <+> vs | vs <- varss] empty)
+
+    gen_set rec
+      = (pat "MV" <+> tuple vars,
+         mk_do [qM rec <+> vs <+> v | vs <- varss | v <- vars] empty)
+
+    gen_unsafeCopy c q rec
+      = (patn "MV" 1 <+> patn c 2,
+         mk_do [q rec <+> vs <> char '1' <+> vs <> char '2' | vs <- varss]
+               empty)
+
+    gen_unsafeMove rec
+      = (patn "MV" 1 <+> patn "MV" 2,
+         mk_do [qM rec <+> vs <> char '1' <+> vs <> char '2' | vs <- varss]
+               empty)
+
+    gen_unsafeGrow rec
+      = (pat "MV" <+> var 'm',
+         mk_do [vs <> char '\'' <+> text "<-"
+                                <+> qM rec <+> vs <+> var 'm' | vs <- varss]
+               $ text "return $" <+> con "MV"
+                                 <+> parens (var 'm' <> char '+' <> var 'n')
+                                 <+> sep (map (<> char '\'') varss))
+
+    gen_initialize rec
+      = (pat "MV", mk_do [qM rec <+> vs | vs <- varss] empty)
+
+    gen_unsafeFreeze rec
+      = (pat "MV",
+         mk_do [vs <> char '\'' <+> text "<-" <+> qG rec <+> vs | vs <- varss]
+               $ text "return $" <+> con "V" <+> var 'n'
+                                 <+> sep [vs <> char '\'' | vs <- varss])
+
+    gen_unsafeThaw rec
+      = (pat "V",
+         mk_do [vs <> char '\'' <+> text "<-" <+> qG rec <+> vs | vs <- varss]
+               $ text "return $" <+> con "MV" <+> var 'n'
+                                 <+> sep [vs <> char '\'' | vs <- varss])
+
+    gen_basicUnsafeIndexM rec
+      = (pat "V" <+> var 'i',
+         mk_do [v <+> text "<-" <+> qG rec <+> vs <+> var 'i'
+                        | vs <- varss | v <- vars]
+               $ text "return" <+> tuple vars)
+
+    gen_elemseq rec
+      = (char '_' <+> tuple vars,
+         vcat $ r : [char '.' <+> r | r <- rs])
+      where
+        r : rs = [qG rec <+> parens (text "undefined :: Vector" <+> v)
+                         <+> v | v <- vars]
+
+    mk_do cmds ret = hang (text "do")
+                          2
+                          $ vcat $ cmds ++ [ret]
+
+    method (s, f) = case f s of
+                      (p,e) ->  text "{-# INLINE" <+> text s <+> text " #-}"
+                                $$ hang (text s <+> p)
+                                   4
+                                   (char '=' <+> e)
+                             
+
+    methods_MVector = [("basicLength",            gen_length "MV")
+                      ,("basicUnsafeSlice",       gen_unsafeSlice "M" "MV")
+                      ,("basicOverlaps",          gen_overlaps)
+                      ,("basicUnsafeNew",         gen_unsafeNew)
+                      ,("basicUnsafeReplicate",   gen_unsafeReplicate)
+                      ,("basicUnsafeRead",        gen_unsafeRead)
+                      ,("basicUnsafeWrite",       gen_unsafeWrite)
+                      ,("basicClear",             gen_clear)
+                      ,("basicSet",               gen_set)
+                      ,("basicUnsafeCopy",        gen_unsafeCopy "MV" qM)
+                      ,("basicUnsafeMove",        gen_unsafeMove)
+                      ,("basicUnsafeGrow",        gen_unsafeGrow)
+                      ,("basicInitialize",        gen_initialize)]
+
+    methods_Vector  = [("basicUnsafeFreeze",      gen_unsafeFreeze)
+                      ,("basicUnsafeThaw",        gen_unsafeThaw)
+                      ,("basicLength",            gen_length "V")
+                      ,("basicUnsafeSlice",       gen_unsafeSlice "G" "V")
+                      ,("basicUnsafeIndexM",      gen_basicUnsafeIndexM)
+                      ,("basicUnsafeCopy",        gen_unsafeCopy "V" qG)
+                      ,("elemseq",                gen_elemseq)]