-- summary refs log blame commit diff
-- path: root/third_party/bazel/rules_haskell/examples/vector/internal/GenUnboxTuple.hs
-- blob: 8debff23a97529e2db4558b503442b6218c2b862 (plain) (tree)














































































































































































































































                                                                                 
{-# LANGUAGE ParallelListComp #-}
module Main where

import System.Environment ( getArgs )
import System.Exit ( exitFailure )
import System.IO ( hPutStrLn, stderr )
import Text.PrettyPrint
import Text.Read ( readMaybe )

-- | Program entry point.
--
-- Takes exactly one command-line argument, the largest tuple size @N@, and
-- prints the rendered output of 'generate' for every size from 2 to @N@ on
-- standard output (nothing is printed when @N@ is below 2).
--
-- A missing, extra or non-numeric argument is reported on standard error with
-- a usage line and a failing exit status, instead of surfacing as a
-- pattern-match or @read@ exception.
main :: IO ()
main = do
  args <- getArgs
  case args of
    -- The element type of the parsed value is fixed to 'Int' by 'generate'.
    [s] | Just n <- readMaybe s ->
            mapM_ (putStrLn . render . generate) [2..n]
    _ -> do
      hPutStrLn stderr
        "usage: GenUnboxTuple N   (N = largest tuple size to generate)"
      exitFailure

-- | Build the generated source text for tuples of size @n@.
--
-- The result is three CPP-guarded sections stacked vertically:
--
--   * @DEFINE_INSTANCES@: the @MVector s@ and @Vector@ data instances, the
--     @Unbox@ instance head, and the @M.MVector MVector@ and
--     @G.Vector Vector@ instances with one INLINE-annotated equation per
--     method;
--
--   * @DEFINE_MUTABLE@: zip and unzip over @MVector s@;
--
--   * @DEFINE_IMMUTABLE@: zip, its stream rewrite rule, and unzip over
--     @Vector@.
--
-- Element type variables are spelled @_a@, @_b@, ...; the matching component
-- vectors are @_as@, @_bs@, ...; constructors are @MV_n@ and @V_n@.
generate :: Int -> Doc
generate n = vcat (concat [instanceSection, mutableSection, immutableSection])
  where
    ------------------------------------------------------------------
    -- Top-level sections

    -- Surround a list of documents with @#ifdef flag@ and @#endif@ lines.
    ifdefBlock flag body = text ("#ifdef " ++ flag) : body ++ [text "#endif"]

    instanceSection = ifdefBlock "DEFINE_INSTANCES"
      [ dataInstance "MVector s" "MV"
      , dataInstance "Vector" "V"
      , classInstance "Unbox"
      , classInstance "M.MVector MVector" <+> text "where"
      , nest 2 (vcat (map renderMethod mvectorMethods))
      , classInstance "G.Vector Vector" <+> text "where"
      , nest 2 (vcat (map renderMethod vectorMethods))
      ]

    mutableSection = ifdefBlock "DEFINE_MUTABLE"
      [ zipDefinition "MVector s" "MV"
      , unzipDefinition "MVector s" "MV"
      ]

    immutableSection = ifdefBlock "DEFINE_IMMUTABLE"
      [ zipDefinition "Vector" "V"
      , zipStreamRule
      , unzipDefinition "Vector" "V"
      ]

    ------------------------------------------------------------------
    -- Shared vocabulary

    -- One type variable per tuple component: _a, _b, ...
    tyVars = [text ['_', letter] | letter <- take n ['a'..]]

    -- One component-vector variable per tuple component: _as, _bs, ...
    vecVars = [v <> char 's' | v <- tyVars]

    -- Append a prime to a name.
    primed doc = doc <> char '\''

    -- A scalar variable wrapped in underscores, e.g. @ident 'n'@ is @_n_@.
    ident ch = text ['_', ch, '_']

    -- Constructor name for this tuple size, e.g. @MV_3@.
    conName prefix = text prefix <> char '_' <> int n

    -- Function names carry the tuple size as a suffix except for pairs.
    numbered base
      | n == 2    = text base
      | otherwise = text base <> int n

    -- Parenthesised, comma-separated list; 'tupleOf' stays on one line,
    -- 'sepTupleOf' is allowed to break across lines.
    tupleOf    = parens . hsep . punctuate comma
    sepTupleOf = parens . sep . punctuate comma

    -- Qualify a method name with the @M.@ or @G.@ module alias.
    inM meth = text "M." <> text meth
    inG meth = text "G." <> text meth

    -- The @(Unbox _a, Unbox _b, ...)@ constraint tuple.
    unboxContext = sepTupleOf (map (text "Unbox" <+>) tyVars)

    -- Constructor pattern binding the length as @_n_@ and every component.
    conPat prefix = parens (conName prefix <+> ident 'n' <+> sep vecVars)

    -- Same pattern with the number @k@ appended to every bound name, so two
    -- patterns can appear side by side in one equation.
    numberedConPat prefix k =
      parens (conName prefix <+> (ident 'n' <> int k)
                             <+> sep [vs <> int k | vs <- vecVars])

    -- A @do@ block made of the given statements followed by a final one.
    doBlock stmts final = hang (text "do") 2 (vcat (stmts ++ [final]))

    -- Apply @callee@ to the @1@- and @2@-suffixed copies of a component.
    onBoth callee vs = callee <+> vs <> char '1' <+> vs <> char '2'

    ------------------------------------------------------------------
    -- Declarations

    dataInstance vecTy prefix = hang lhs 4 rhs
      where
        lhs = hsep [text "data instance", text vecTy, tupleOf tyVars]
        rhs = hsep [ char '=', conName prefix, text "{-# UNPACK #-} !Int"
                   , vcat [char '!' <> parens (text vecTy <+> v) | v <- tyVars]
                   ]

    classInstance cls
      = text "instance" <+> unboxContext
                        <+> text "=>" <+> text cls <+> tupleOf tyVars

    zipDefinition vecTy prefix
      = sep [haddock, signature, pragma, equation, lenBinding]
      where
        fname     = numbered "zip"
        haddock   = text "-- | /O(1)/ Zip" <+> int n <+> text "vectors"
        signature = fname <+> text "::"
                          <+> unboxContext
                          <+> text "=>"
                          <+> sep (punctuate (text " ->")
                                             (map (text vecTy <+>) tyVars))
                          <+> text "->"
                          <+> text vecTy <+> tupleOf tyVars
        pragma    = text "{-# INLINE_FUSED"  <+> fname <+> text "#-}"
        equation  = fname <+> sep vecVars
                          <+> text "="
                          <+> conName prefix
                          <+> text "len"
                          <+> sep (map slice vecVars)
        -- Every component is cut down to the common length.
        slice vs  = parens (text "unsafeSlice" <+> char '0'
                                               <+> text "len"
                                               <+> vs)
        lenBinding = nest 2 (hang (text "where") 2 lenEquation)
        lenEquation = text "len ="
                      <+> sep (punctuate (text " `delayed_min`")
                                         (map (text "length" <+>) vecVars))

    zipStreamRule = hang ruleHead 2 (ruleBody $$ text "#-}")
      where
        ruleHead = text "{-# RULES" <+> text "\"stream/" <> numbered "zip"
                   <> text " [Vector.Unboxed]\" forall" <+> sep vecVars
                   <+> char '.'
        ruleBody = text "G.stream" <+> parens (numbered "zip" <+> sep vecVars)
                   <+> char '='
                   <+> text "Bundle." <> numbered "zipWith"
                   <+> tupleOf (replicate n empty)
                   <+> sep (map (parens . (text "G.stream" <+>)) vecVars)

    unzipDefinition vecTy prefix = sep [haddock, signature, pragma, equation]
      where
        fname     = numbered "unzip"
        haddock   = text "-- | /O(1)/ Unzip" <+> int n <+> text "vectors"
        signature = fname <+> text "::"
                          <+> unboxContext
                          <+> text "=>"
                          <+> text vecTy <+> tupleOf tyVars
                          <+> text "->"
                          <+> sepTupleOf (map (text vecTy <+>) tyVars)
        pragma    = text "{-# INLINE" <+> fname <+> text "#-}"
        equation  = fname <+> conPat prefix <+> text "="
                          <+> sepTupleOf vecVars

    ------------------------------------------------------------------
    -- Method generators.  Each one receives the method name and returns
    -- the pair (argument patterns, right-hand side).

    lengthMethod prefix _ = (conPat prefix, ident 'n')

    sliceMethod qual prefix meth = (lhs, rhs)
      where
        lhs = ident 'i' <+> ident 'm' <+> conPat prefix
        rhs = conName prefix <+> ident 'm' <+> vcat (map sliceOf vecVars)
        sliceOf vs = parens (text qual <> char '.' <> text meth
                                       <+> ident 'i' <+> ident 'm' <+> vs)

    overlapsMethod meth =
      ( numberedConPat "MV" 1 <+> numberedConPat "MV" 2
      , vcat (firstTest : map (text "||" <+>) laterTests)
      )
      where
        firstTest : laterTests = map (onBoth (inM meth)) vecVars

    newMethod meth =
      ( ident 'n'
      , doBlock [vs <+> text "<-" <+> inM meth <+> ident 'n' | vs <- vecVars]
                returnFreshMV
      )

    replicateMethod meth =
      ( ident 'n' <+> tupleOf tyVars
      , doBlock (zipWith bind tyVars vecVars) returnFreshMV
      )
      where
        bind v vs = vs <+> text "<-" <+> inM meth <+> ident 'n' <+> v

    -- Final statement shared by the two allocating methods.
    returnFreshMV = text "return $" <+> conName "MV" <+> ident 'n'
                                    <+> sep vecVars

    -- Element lookup: used for both the mutable read and the immutable
    -- monadic index, which differ only in qualifier and constructor.
    readMethod qualify prefix meth =
      ( conPat prefix <+> ident 'i'
      , doBlock (zipWith bind tyVars vecVars) (text "return" <+> tupleOf tyVars)
      )
      where
        bind v vs = v <+> text "<-" <+> qualify meth <+> vs <+> ident 'i'

    writeMethod meth =
      ( conPat "MV" <+> ident 'i' <+> tupleOf tyVars
      , doBlock (zipWith put tyVars vecVars) empty
      )
      where
        put v vs = inM meth <+> vs <+> ident 'i' <+> v

    -- Run the method on every component in turn (clear and initialize).
    perComponentMethod meth =
      (conPat "MV", doBlock (map (inM meth <+>) vecVars) empty)

    setMethod meth =
      ( conPat "MV" <+> tupleOf tyVars
      , doBlock (zipWith fill vecVars tyVars) empty
      )
      where
        fill vs v = inM meth <+> vs <+> v

    -- Component-wise transfer from the second argument into the first;
    -- the source constructor and the qualifier vary (copy and move).
    transferMethod prefix qualify meth =
      ( numberedConPat "MV" 1 <+> numberedConPat prefix 2
      , doBlock (map (onBoth (qualify meth)) vecVars) empty
      )

    growMethod meth =
      ( conPat "MV" <+> ident 'm'
      , doBlock [ primed vs <+> text "<-" <+> inM meth <+> vs <+> ident 'm'
                | vs <- vecVars ]
                (text "return $" <+> conName "MV"
                                 <+> parens (ident 'm' <> char '+' <> ident 'n')
                                 <+> sep (map primed vecVars))
      )

    -- Freeze and thaw: convert each component, then rebuild with the
    -- other constructor.
    convertMethod fromCon toCon meth =
      ( conPat fromCon
      , doBlock [primed vs <+> text "<-" <+> inG meth <+> vs | vs <- vecVars]
                (text "return $" <+> conName toCon <+> ident 'n'
                                 <+> sep (map primed vecVars))
      )

    elemseqMethod meth =
      ( char '_' <+> tupleOf tyVars
      , vcat (firstSeq : map (char '.' <+>) laterSeqs)
      )
      where
        firstSeq : laterSeqs =
          [ inG meth <+> parens (text "undefined :: Vector" <+> v) <+> v
          | v <- tyVars ]

    ------------------------------------------------------------------
    -- Method tables and their rendering

    -- An INLINE pragma followed by the method's defining equation.
    renderMethod (name, gen)
      = inlinePragma $$ hang (text name <+> lhs) 4 (char '=' <+> rhs)
      where
        (lhs, rhs)   = gen name
        inlinePragma = text "{-# INLINE" <+> text name <+> text " #-}"

    mvectorMethods =
      [ ("basicLength",          lengthMethod "MV")
      , ("basicUnsafeSlice",     sliceMethod "M" "MV")
      , ("basicOverlaps",        overlapsMethod)
      , ("basicUnsafeNew",       newMethod)
      , ("basicUnsafeReplicate", replicateMethod)
      , ("basicUnsafeRead",      readMethod inM "MV")
      , ("basicUnsafeWrite",     writeMethod)
      , ("basicClear",           perComponentMethod)
      , ("basicSet",             setMethod)
      , ("basicUnsafeCopy",      transferMethod "MV" inM)
      , ("basicUnsafeMove",      transferMethod "MV" inM)
      , ("basicUnsafeGrow",      growMethod)
      , ("basicInitialize",      perComponentMethod)
      ]

    vectorMethods =
      [ ("basicUnsafeFreeze",    convertMethod "MV" "V")
      , ("basicUnsafeThaw",      convertMethod "V" "MV")
      , ("basicLength",          lengthMethod "V")
      , ("basicUnsafeSlice",     sliceMethod "G" "V")
      , ("basicUnsafeIndexM",    readMethod inG "V")
      , ("basicUnsafeCopy",      transferMethod "V" inG)
      , ("elemseq",              elemseqMethod)
      ]