Skip to content

Commit

Permalink
Merge pull request #506 from IntersectMBO/lehins/use-mempack
Browse files Browse the repository at this point in the history
Add support for `mempack`
  • Loading branch information
lehins authored Nov 20, 2024
2 parents 97a3b8b + 03d23d0 commit 85c2384
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 63 deletions.
2 changes: 2 additions & 0 deletions cardano-crypto-class/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
**DO NOT RELEASE YET** - MemLocking and secure forgettin interface has not yet
solidified. Ask @lehins if backport is needed.

* Add required `HashAlgorithm` constraint to `Hash` serialization.
* Add `MemPack` instance for `Hash` and `PackedBytes`
* Introduce memory locking and secure forgetting functionality:
[#255](https://github.com/input-output-hk/cardano-base/pull/255)
[#404](https://github.com/input-output-hk/cardano-base/pull/404)
Expand Down
3 changes: 2 additions & 1 deletion cardano-crypto-class/cardano-crypto-class.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ library
, heapwords
, io-classes >= 1.4.1
, memory
, mempack
, mtl
, nothunks
, primitive
, primitive >= 0.8
, serialise
, template-haskell
, th-compat
Expand Down
14 changes: 9 additions & 5 deletions cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
Expand Down Expand Up @@ -66,6 +67,8 @@ import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as Base16
import qualified Data.ByteString.Char8 as BSC
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as SBS
import Data.MemPack (StateT(StateT), FailT(FailT), MemPack, Unpack(Unpack))
import Data.Word (Word8)
import Numeric.Natural (Natural)

Expand All @@ -84,8 +87,7 @@ import Control.DeepSeq (NFData)

import NoThunks.Class (NoThunks)

import Cardano.Binary (Encoding, FromCBOR(..), Size, ToCBOR(..), decodeBytes,
serialize')
import Cardano.Binary (Encoding, FromCBOR(..), Size, ToCBOR(..), serialize')
import Cardano.Crypto.PackedBytes
import Cardano.Crypto.Util (decodeHexString)
import Cardano.HeapWords (HeapWords (..))
Expand All @@ -110,6 +112,8 @@ sizeHash _ = fromInteger (natVal (Proxy @(SizeHash h)))
newtype Hash h a = UnsafeHashRep (PackedBytes (SizeHash h))
deriving (Eq, Ord, Generic, NoThunks, NFData)

deriving instance HashAlgorithm h => MemPack (Hash h a)

-- | This instance is meant to be used with @TemplateHaskell@
--
-- >>> import Cardano.Crypto.Hash.Class (Hash)
Expand Down Expand Up @@ -353,14 +357,14 @@ instance (HashAlgorithm h, Typeable a) => ToCBOR (Hash h a) where

instance (HashAlgorithm h, Typeable a) => FromCBOR (Hash h a) where
fromCBOR = do
bs <- decodeBytes
case hashFromBytes bs of
sbs <- fromCBOR
case hashFromBytesShort sbs of
Just x -> return x
Nothing -> fail $ "hash bytes wrong size, expected " ++ show expected
++ " but got " ++ show actual
where
expected = sizeHash (Proxy :: Proxy h)
actual = BS.length bs
actual = SBS.length sbs

--
-- Deprecated
Expand Down
149 changes: 92 additions & 57 deletions cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,25 @@ module Cardano.Crypto.PackedBytes
import Codec.Serialise (Serialise(..))
import Codec.Serialise.Decoding (decodeBytes)
import Codec.Serialise.Encoding (encodeBytes)
import Control.DeepSeq
import Control.DeepSeq (NFData(..))
import Control.Monad (guard)
import Control.Monad.Primitive
import Control.Monad.Primitive (primitive_)
import Control.Monad.Reader (MonadReader(ask), MonadTrans(lift))
import Control.Monad.State.Strict (MonadState(state))
import Data.Bits
import Data.ByteString
import Data.ByteString.Internal as BS (accursedUnutterablePerformIO,
fromForeignPtr, toForeignPtr)
import Data.ByteString.Short.Internal as SBS
import Data.MemPack (guardAdvanceUnpack, st_, MemPack(..), Pack(Pack))
import Data.MemPack.Buffer (Buffer(buffer), byteArrayToShortByteString, pinnedByteArrayToForeignPtr)
import Data.Primitive.ByteArray
import Data.Primitive.PrimArray (PrimArray(..), imapPrimArray, indexPrimArray)
import Data.Typeable
import Foreign.ForeignPtr
import Foreign.Ptr (castPtr)
import Foreign.Storable (peekByteOff)
import GHC.Exts
import GHC.ForeignPtr (ForeignPtr(ForeignPtr), ForeignPtrContents(PlainPtr))
#if MIN_VERSION_base(4,15,0)
import GHC.ForeignPtr (unsafeWithForeignPtr)
#endif
Expand Down Expand Up @@ -92,9 +95,46 @@ instance NFData (PackedBytes n) where
rnf PackedBytes32 {} = ()
rnf PackedBytes# {} = ()

instance Serialise (PackedBytes n) where
instance KnownNat n => MemPack (PackedBytes n) where
packedByteCount = fromInteger @Int . natVal
{-# INLINE packedByteCount #-}
packM pb = do
let !len@(I# len#) = packedByteCount pb
i@(I# i#) <- state $ \i -> (i, i + len)
mba@(MutableByteArray mba#) <- ask
Pack $ \_ -> lift $ case pb of
PackedBytes8 w -> writeWord64BE mba i w
PackedBytes28 w0 w1 w2 w3 -> do
writeWord64BE mba i w0
writeWord64BE mba (i + 8) w1
writeWord64BE mba (i + 16) w2
writeWord32BE mba (i + 24) w3
PackedBytes32 w0 w1 w2 w3 -> do
writeWord64BE mba i w0
writeWord64BE mba (i + 8) w1
writeWord64BE mba (i + 16) w2
writeWord64BE mba (i + 24) w3
PackedBytes# ba# ->
st_ (copyByteArray# ba# 0# mba# i# len#)
{-# INLINE packM #-}
unpackM = do
let !len = fromInteger @Int $ natVal' (proxy# :: Proxy# n)
curPos@(I# curPos#) <- guardAdvanceUnpack len
buf <- ask
pure $! buffer buf
(\ba# -> packBytes (SBS.SBS ba#) curPos)
-- Usage of `accursedUnutterablePerformIO` is safe below because there are no memory
-- allocations happening that depend on the IO monad that we are excaping here. All
-- IO actions are morally pure reads using pointers into the immutable
-- memory. Furthermore, in the place where ByteArray is allocated in
-- `packPinnedPtrN`, mutation and freezing are encapsulated with `runST` and is not
-- related to the `IO` we are escaping.
(\addr# -> accursedUnutterablePerformIO $ packPinnedPtr (Ptr (addr# `plusAddr#` curPos#)))
{-# INLINE unpackM #-}

instance KnownNat n => Serialise (PackedBytes n) where
encode = encodeBytes . unpackPinnedBytes
decode = packPinnedBytesN <$> decodeBytes
decode = packPinnedBytes <$> decodeBytes

xorPackedBytes :: PackedBytes n -> PackedBytes n -> PackedBytes n
xorPackedBytes (PackedBytes8 x) (PackedBytes8 y) = PackedBytes8 (x `xor` y)
Expand Down Expand Up @@ -219,55 +259,59 @@ packBytesMaybe bs offset = do
guard (offset >= 0)
guard (size <= bufferSize - offset)
Just $ packBytes bs offset
{-# INLINE packBytesMaybe #-}


packPinnedPtr8 :: Ptr a -> IO (PackedBytes 8)
packPinnedPtr8 = fmap PackedBytes8 . (`peekWord64BE` 0)
{-# INLINE packPinnedPtr8 #-}

packPinnedPtr28 :: Ptr a -> IO (PackedBytes 28)
packPinnedPtr28 ptr =
PackedBytes28
<$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord32BE ptr 24
{-# INLINE packPinnedPtr28 #-}

packPinnedPtr32 :: Ptr a -> IO (PackedBytes 32)
packPinnedPtr32 ptr =
PackedBytes32 <$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord64BE ptr 24
{-# INLINE packPinnedPtr32 #-}

packPinnedPtrN :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n)
packPinnedPtrN (Ptr addr#) = pure $! PackedBytes# ba#
where
!(ByteArray ba#) = withMutableByteArray len $ \(MutableByteArray mba#) ->
st_ (copyAddrToByteArray# addr# mba# 0# len#)
!len@(I# len#) = fromInteger (natVal' (proxy# :: Proxy# n))
{-# INLINE packPinnedPtrN #-}


packPinnedBytes8 :: ByteString -> PackedBytes 8
packPinnedBytes8 bs = unsafeWithByteStringPtr bs (fmap PackedBytes8 . (`peekWord64BE` 0))
{-# INLINE packPinnedBytes8 #-}

packPinnedBytes28 :: ByteString -> PackedBytes 28
packPinnedBytes28 bs =
unsafeWithByteStringPtr bs $ \ptr ->
PackedBytes28
<$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord32BE ptr 24
{-# INLINE packPinnedBytes28 #-}

packPinnedBytes32 :: ByteString -> PackedBytes 32
packPinnedBytes32 bs =
unsafeWithByteStringPtr bs $ \ptr -> PackedBytes32 <$> peekWord64BE ptr 0
<*> peekWord64BE ptr 8
<*> peekWord64BE ptr 16
<*> peekWord64BE ptr 24
{-# INLINE packPinnedBytes32 #-}

packPinnedBytesN :: ByteString -> PackedBytes n
packPinnedBytesN bs =
case toShort bs of
SBS ba# -> PackedBytes# ba#
{-# INLINE packPinnedBytesN #-}


packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n
packPinnedBytes bs =
packPinnedPtr :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n)
packPinnedPtr bs =
let px = Proxy :: Proxy n
in case sameNat px (Proxy :: Proxy 8) of
Just Refl -> packPinnedBytes8 bs
Just Refl -> packPinnedPtr8 bs
Nothing -> case sameNat px (Proxy :: Proxy 28) of
Just Refl -> packPinnedBytes28 bs
Just Refl -> packPinnedPtr28 bs
Nothing -> case sameNat px (Proxy :: Proxy 32) of
Just Refl -> packPinnedBytes32 bs
Nothing -> packPinnedBytesN bs
{-# INLINE[1] packPinnedBytes #-}

Just Refl -> packPinnedPtr32 bs
Nothing -> packPinnedPtrN bs
{-# INLINE[1] packPinnedPtr #-}
{-# RULES
"packPinnedBytes8" packPinnedBytes = packPinnedBytes8
"packPinnedBytes28" packPinnedBytes = packPinnedBytes28
"packPinnedBytes32" packPinnedBytes = packPinnedBytes32
"packPinnedPtr8" packPinnedPtr = packPinnedPtr8
"packPinnedPtr28" packPinnedPtr = packPinnedPtr28
"packPinnedPtr32" packPinnedPtr = packPinnedPtr32
#-}

packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n
packPinnedBytes bs = unsafeWithByteStringPtr bs packPinnedPtr
{-# INLINE packPinnedBytes #-}


--- Primitive architecture agnostic helpers

Expand Down Expand Up @@ -358,22 +402,13 @@ writeWord32BE (MutableByteArray mba#) (I# i#) w =
#endif
{-# INLINE writeWord32BE #-}

byteArrayToShortByteString :: ByteArray -> ShortByteString
byteArrayToShortByteString (ByteArray ba#) = SBS ba#
{-# INLINE byteArrayToShortByteString #-}

byteArrayToByteString :: ByteArray -> ByteString
byteArrayToByteString ba
byteArrayToByteString ba@(ByteArray ba#)
| isByteArrayPinned ba =
BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba) 0 (sizeofByteArray ba)
BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba#) 0 (sizeofByteArray ba)
| otherwise = SBS.fromShort (byteArrayToShortByteString ba)
{-# INLINE byteArrayToByteString #-}

pinnedByteArrayToForeignPtr :: ByteArray -> ForeignPtr a
pinnedByteArrayToForeignPtr (ByteArray ba#) =
ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}

-- Usage of `accursedUnutterablePerformIO` here is safe because we only use it
-- for indexing into an immutable `ByteString`, which is analogous to
-- `Data.ByteString.index`. Make sure you know what you are doing before using
Expand Down
1 change: 1 addition & 0 deletions cardano-crypto-tests/cardano-crypto-tests.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ library
, deepseq
, formatting
, io-classes >= 1.4.0
, mempack
, mtl
, nothunks
, pretty-show
Expand Down
7 changes: 7 additions & 0 deletions cardano-crypto-tests/src/Test/Crypto/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import qualified Data.Bits as Bits (xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as SBS
import Data.Maybe (fromJust)
import Data.MemPack
import Data.Proxy (Proxy(..))
import Data.String (fromString)
import GHC.TypeLits
Expand Down Expand Up @@ -63,9 +64,15 @@ testHashAlgorithm p =
, testProperty "hashFromStringAsHex/fromString" $ prop_hash_hashFromStringAsHex_fromString @h @Float
, testProperty "show/read" $ prop_hash_show_read @h @Float
, testProperty "NoThunks" $ prop_no_thunks @(Hash h Int)
, testProperty "MemPack RoundTrip" $ prop_MemPackRoundTrip @(Hash h Int)
]
where n = hashAlgorithmName p

prop_MemPackRoundTrip :: forall a. (MemPack a, Eq a, Show a) => a -> Property
prop_MemPackRoundTrip a =
unpackError (pack a) === a .&&.
unpackError (packByteString a) === a

testSodiumHashAlgorithm
:: forall proxy h. NaCl.SodiumHashAlgorithm h
=> Lock
Expand Down

0 comments on commit 85c2384

Please sign in to comment.