diff --git a/cardano-crypto-class/CHANGELOG.md b/cardano-crypto-class/CHANGELOG.md index a61b593b0..b834edbc6 100644 --- a/cardano-crypto-class/CHANGELOG.md +++ b/cardano-crypto-class/CHANGELOG.md @@ -20,6 +20,9 @@ solidified. Ask @lehins if backport is needed. [#404](https://github.com/input-output-hk/cardano-base/pull/404) * Restructuring of libsodium bindings and related APIs: [#404](https://github.com/input-output-hk/cardano-base/pull/404) +* Re-introduction of non-mlocked KES implementations to support a smoother + migration path: + [#504](https://github.com/IntersectMBO/cardano-base/pull/504) ## 2.1.0.2 diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 5d98c540f..5be701d9a 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -39,6 +39,7 @@ library import: base, project-config hs-source-dirs: src exposed-modules: + Cardano.Crypto.DirectSerialise Cardano.Crypto.DSIGN Cardano.Crypto.DSIGN.Class Cardano.Crypto.DSIGN.Ed25519 diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs index 01aa4a4db..b17a70a0e 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -66,14 +67,17 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.PinnedSizedBytes ( PinnedSizedBytes , psbUseAsSizedPtr + , psbUseAsCPtrLen , psbToByteString , psbFromByteStringCheck + , psbCreate , psbCreateSized , psbCreateSizedResult ) import Cardano.Crypto.Seed import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Foreign +import Cardano.Crypto.DirectSerialise @@ -261,7 +265,7 @@ instance DSIGNMAlgorithm Ed25519DSIGN where stToIO $ do cOrError $ unsafeIOToST $ c_crypto_sign_ed25519_sk_to_pk pkPtr skPtr - throwOnErrno "deriveVerKeyDSIGNM @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno + throwOnErrno "deriveVerKeyDSIGN @Ed25519DSIGN" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno return psb @@ -365,3 +369,48 @@ instance TypeError ('Text "CBOR encoding would violate mlocking guarantees") instance TypeError ('Text "CBOR decoding would violate mlocking guarantees") => FromCBOR (SignKeyDSIGNM Ed25519DSIGN) where fromCBOR = error "unsupported" + +instance DirectSerialise (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. The + -- latter contains both the seed and the 32-byte verification key, which is + -- convenient, but redundant, since we can always reconstruct it from the + -- seed. This is also reflected in the 'SizeSignKeyDSIGNM', which equals + -- 'SeedSizeDSIGNM' == 32, rather than reporting the in-memory size of 64. + directSerialise push sk = do + bracket + (getSeedDSIGNM (Proxy @Ed25519DSIGN) sk) + mlockedSeedFinalize + (\seed -> mlockedSeedUseAsCPtr seed $ \ptr -> + push + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN))) + +instance DirectDeserialise (SignKeyDSIGNM Ed25519DSIGN) where + -- /Note:/ We only serialize the 32-byte seed, not the full 64-byte key. See + -- the DirectSerialise instance above. + directDeserialise pull = do + bracket + mlockedSeedNew + mlockedSeedFinalize + (\seed -> do + mlockedSeedUseAsCPtr seed $ \ptr -> do + pull + (castPtr ptr) + (fromIntegral $ seedSizeDSIGN (Proxy @Ed25519DSIGN)) + genKeyDSIGNM seed + ) + +instance DirectSerialise (VerKeyDSIGN Ed25519DSIGN) where + directSerialise push (VerKeyEd25519DSIGN psb) = do + psbUseAsCPtrLen psb $ \ptr _ -> + push + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + +instance DirectDeserialise (VerKeyDSIGN Ed25519DSIGN) where + directDeserialise pull = do + psb <- psbCreate $ \ptr -> + pull + (castPtr ptr) + (fromIntegral $ sizeVerKeyDSIGN (Proxy @Ed25519DSIGN)) + return $! VerKeyEd25519DSIGN psb diff --git a/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs new file mode 100644 index 000000000..a07855c6d --- /dev/null +++ b/cardano-crypto-class/src/Cardano/Crypto/DirectSerialise.hs @@ -0,0 +1,195 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Direct (de-)serialisation to / from raw memory. +-- +-- The purpose of the typeclasses in this module is to abstract over data +-- structures that can expose the data they store as one or more raw 'Ptr's, +-- without any additional memory copying or conversion to intermediate data +-- structures. +-- +-- This is useful for transmitting data like KES SignKeys over a socket +-- connection: by accessing the memory directly and copying it into or out of +-- a file descriptor, without going through an intermediate @ByteString@ +-- representation (or other data structure that resides in the GHC heap), we +-- can more easily assure that the data is never written to disk, including +-- swap, which is an important requirement for KES. +module Cardano.Crypto.DirectSerialise +where + +import Foreign.Ptr +import Foreign.C.Types +import Control.Monad (when) +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Monad.Class.MonadST (MonadST, stToIO) +import Control.Exception +import Data.STRef (newSTRef, readSTRef, writeSTRef) +import Cardano.Crypto.Libsodium.Memory (copyMem) + +data SizeCheckException = + SizeCheckException + { expectedSize :: Int + , actualSize :: Int + } + deriving (Show) + +instance Exception SizeCheckException where + +sizeCheckFailed :: Int -> Int -> m () +sizeCheckFailed ex ac = + throw $ SizeCheckException ex ac + +-- | Direct deserialization from raw memory. +-- +-- @directDeserialise f@ should allocate a new value of type 'a', and +-- call @f@ with a pointer to the raw memory to be filled. @f@ may be called +-- multiple times, for data structures that store their data in multiple +-- non-contiguous blocks of memory. +-- +-- The order in which memory blocks are visited matters. +class DirectDeserialise a where + directDeserialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> m a + +-- | Direct serialization to raw memory. +-- +-- @directSerialise f x@ should call @f@ to expose the raw memory underyling +-- @x@. For data types that store their data in multiple non-contiguous blocks +-- of memory, @f@ may be called multiple times, once for each block. +-- +-- The order in which memory blocks are visited matters. +class DirectSerialise a where + directSerialise :: (MonadST m, MonadThrow m) => (Ptr CChar -> CSize -> m ()) -> a -> m () + +-- | Helper function for bounds-checked serialization. +-- Verifies that no more than the maximum number of bytes are written, and +-- returns the actual number of bytes written. +directSerialiseTo :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => (Int -> Ptr CChar -> CSize -> m ()) + -> Int + -> a + -> m Int +directSerialiseTo writeBytes dstsize val = do + posRef <- stToIO $ newSTRef 0 + let pusher :: Ptr CChar -> CSize -> m () + pusher src srcsize = do + pos <- stToIO $ readSTRef posRef + let pos' = pos + fromIntegral srcsize + when (pos' > dstsize) $ + sizeCheckFailed (dstsize - pos) (pos' - pos) + writeBytes pos src (fromIntegral srcsize) + stToIO $ writeSTRef posRef pos' + directSerialise pusher val + stToIO $ readSTRef posRef + +-- | Helper function for size-checked serialization. +-- Verifies that exactly the specified number of bytes are written. +directSerialiseToChecked :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => (Int -> Ptr CChar -> CSize -> m ()) + -> Int + -> a + -> m () +directSerialiseToChecked writeBytes dstsize val = do + bytesWritten <- directSerialiseTo writeBytes dstsize val + when (bytesWritten /= dstsize) $ + sizeCheckFailed dstsize bytesWritten + +-- | Helper function for the common use case of serializing to an in-memory +-- buffer. +-- Verifies that no more than the maximum number of bytes are written, and +-- returns the actual number of bytes written. +directSerialiseBuf :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => Ptr CChar + -> Int + -> a + -> m Int +directSerialiseBuf dst = + directSerialiseTo (copyMem . plusPtr dst) + +-- | Helper function for size-checked serialization to an in-memory buffer. +-- Verifies that exactly the specified number of bytes are written. +directSerialiseBufChecked :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => Ptr CChar + -> Int + -> a + -> m () +directSerialiseBufChecked buf dstsize val = do + bytesWritten <- directSerialiseBuf buf dstsize val + when (bytesWritten /= dstsize) $ + sizeCheckFailed dstsize bytesWritten + +-- | Helper function for size-checked deserialization. +-- Verifies that no more than the maximum number of bytes are read, and returns +-- the actual number of bytes read. +directDeserialiseFrom :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => (Int -> Ptr CChar -> CSize -> m ()) + -> Int + -> m (a, Int) +directDeserialiseFrom readBytes srcsize = do + posRef <- stToIO $ newSTRef 0 + let puller :: Ptr CChar -> CSize -> m () + puller dst dstsize = do + pos <- stToIO $ readSTRef posRef + let pos' = pos + fromIntegral dstsize + when (pos' > srcsize) $ + sizeCheckFailed (srcsize - pos) (pos' - pos) + readBytes pos dst (fromIntegral dstsize) + stToIO $ writeSTRef posRef pos' + (,) <$> directDeserialise puller <*> stToIO (readSTRef posRef) + +-- | Helper function for size-checked deserialization. +-- Verifies that exactly the specified number of bytes are read. +directDeserialiseFromChecked :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => (Int -> Ptr CChar -> CSize -> m ()) + -> Int + -> m a +directDeserialiseFromChecked readBytes srcsize = do + (r, bytesRead) <- directDeserialiseFrom readBytes srcsize + when (bytesRead /= srcsize) $ + sizeCheckFailed srcsize bytesRead + return r + +-- | Helper function for the common use case of deserializing from an in-memory +-- buffer. +-- Verifies that no more than the maximum number of bytes are read, and returns +-- the actual number of bytes read. +directDeserialiseBuf :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => Ptr CChar + -> Int + -> m (a, Int) +directDeserialiseBuf src = + directDeserialiseFrom (\pos dst -> copyMem dst (plusPtr src pos)) + +-- | Helper function for size-checked deserialization from an in-memory buffer. +-- Verifies that exactly the specified number of bytes are read. +directDeserialiseBufChecked :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => Ptr CChar + -> Int + -> m a +directDeserialiseBufChecked buf srcsize = do + (r, bytesRead) <- directDeserialiseBuf buf srcsize + when (bytesRead /= srcsize) $ + sizeCheckFailed srcsize bytesRead + return r diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs index 81cb10aac..9c7f3f5c0 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs @@ -52,12 +52,18 @@ module Cardano.Crypto.KES.Class , sizeSignKeyKES , seedSizeKES - -- * Unsound API + -- * Unsound APIs + , UnsoundKESAlgorithm (..) , encodeSignKeyKES , decodeSignKeyKES , rawDeserialiseSignKeyKES + , UnsoundPureKESAlgorithm (..) + , unsoundPureSignedKES + , encodeUnsoundPureSignKeyKES + , decodeUnsoundPureSignKeyKES + -- * Utility functions -- These are used between multiple KES implementations. User code will -- most likely not need these, but they are required for recursive @@ -66,6 +72,7 @@ module Cardano.Crypto.KES.Class -- convenience. , hashPairOfVKeys , mungeName + , unsoundPureSignKeyKESToSoundSignKeyKESViaSer ) where @@ -90,6 +97,7 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium (MLockedAllocator, mlockedMalloc) import Cardano.Crypto.Hash.Class (HashAlgorithm, Hash, hashWith) import Cardano.Crypto.DSIGN.Class (failSizeCheck) +import Cardano.Crypto.Seed class ( Typeable v , Show (VerKeyKES v) @@ -277,6 +285,46 @@ updateKES updateKES = updateKESWith mlockedMalloc +-- | Pure implementations of the core KES operations. These are unsound, because +-- proper handling of KES secrets (seeds, sign keys) requires mlocking and +-- deterministic erasure (\"secure forgetting\"), which is not possible in pure +-- code. +-- This API is only provided for testing purposes; it must not be used to +-- generate or use real KES keys. +class KESAlgorithm v => UnsoundPureKESAlgorithm v where + data UnsoundPureSignKeyKES v :: Type + + unsoundPureSignKES + :: forall a. (Signable v a) + => ContextKES v + -> Period -- ^ The /current/ period for the key + -> a + -> UnsoundPureSignKeyKES v + -> SigKES v + + unsoundPureUpdateKES + :: ContextKES v + -> UnsoundPureSignKeyKES v + -> Period -- ^ The /current/ period for the key, not the target period. + -> Maybe (UnsoundPureSignKeyKES v) + + unsoundPureGenKeyKES + :: Seed + -> UnsoundPureSignKeyKES v + + unsoundPureDeriveVerKeyKES + :: UnsoundPureSignKeyKES v + -> VerKeyKES v + + unsoundPureSignKeyKESToSoundSignKeyKES + :: (MonadST m, MonadThrow m) + => UnsoundPureSignKeyKES v + -> m (SignKeyKES v) + + rawSerialiseUnsoundPureSignKeyKES :: UnsoundPureSignKeyKES v -> ByteString + rawDeserialiseUnsoundPureSignKeyKES :: ByteString -> Maybe (UnsoundPureSignKeyKES v) + + -- | Unsound operations on KES sign keys. These operations violate secure -- forgetting constraints by leaking secrets to unprotected memory. Consider -- using the 'DirectSerialise' / 'DirectDeserialise' APIs instead. @@ -294,6 +342,18 @@ rawDeserialiseSignKeyKES :: -> m (Maybe (SignKeyKES v)) rawDeserialiseSignKeyKES = rawDeserialiseSignKeyKESWith mlockedMalloc +-- | Helper function for implementing 'unsoundPureSignKeyKESToSoundSignKeyKES' +-- for KES algorithms that support both 'UnsoundKESAlgorithm' and +-- 'UnsoundPureKESAlgorithm'. For such KES algorithms, unsound sign keys can be +-- marshalled to sound sign keys by serializing and then deserializing them. +unsoundPureSignKeyKESToSoundSignKeyKESViaSer + :: (MonadST m, MonadThrow m, UnsoundKESAlgorithm k, UnsoundPureKESAlgorithm k) + => UnsoundPureSignKeyKES k + -> m (SignKeyKES k) +unsoundPureSignKeyKESToSoundSignKeyKESViaSer sk = + maybe (error "unsoundPureSignKeyKESToSoundSignKeyKES: deserialisation failure") return =<< + (rawDeserialiseSignKeyKES . rawSerialiseUnsoundPureSignKeyKES $ sk) + -- | Subclass for KES algorithms that embed a copy of the VerKey into the -- signature itself, rather than relying on the externally supplied VerKey @@ -362,6 +422,9 @@ instance ( TypeError ('Text "Ord not supported for verification keys, use the ha encodeVerKeyKES :: KESAlgorithm v => VerKeyKES v -> Encoding encodeVerKeyKES = encodeBytes . rawSerialiseVerKeyKES +encodeUnsoundPureSignKeyKES :: UnsoundPureKESAlgorithm v => UnsoundPureSignKeyKES v -> Encoding +encodeUnsoundPureSignKeyKES = encodeBytes . rawSerialiseUnsoundPureSignKeyKES + encodeSigKES :: KESAlgorithm v => SigKES v -> Encoding encodeSigKES = encodeBytes . rawSerialiseSigKES @@ -379,6 +442,14 @@ decodeVerKeyKES = do Nothing -> failSizeCheck "decodeVerKeyKES" "key" bs (sizeVerKeyKES (Proxy :: Proxy v)) {-# INLINE decodeVerKeyKES #-} +decodeUnsoundPureSignKeyKES :: forall v s. UnsoundPureKESAlgorithm v => Decoder s (UnsoundPureSignKeyKES v) +decodeUnsoundPureSignKeyKES = do + bs <- decodeBytes + case rawDeserialiseUnsoundPureSignKeyKES bs of + Just vk -> return vk + Nothing -> failSizeCheck "decodeUnsoundPureSignKeyKES" "key" bs (sizeSignKeyKES (Proxy :: Proxy v)) +{-# INLINE decodeUnsoundPureSignKeyKES #-} + decodeSigKES :: forall v s. KESAlgorithm v => Decoder s (SigKES v) decodeSigKES = do bs <- decodeBytes @@ -435,6 +506,16 @@ verifySignedKES -> Either String () verifySignedKES ctxt vk j a (SignedKES sig) = verifyKES ctxt vk j a sig +unsoundPureSignedKES + :: (UnsoundPureKESAlgorithm v, Signable v a) + => ContextKES v + -> Period + -> a + -> UnsoundPureSignKeyKES v + -> SignedKES v a +unsoundPureSignedKES ctxt time a key = SignedKES $ unsoundPureSignKES ctxt time a key + + encodeSignedKES :: KESAlgorithm v => SignedKES v a -> Encoding encodeSignedKES (SignedKES s) = encodeSigKES s diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs index cce1102f1..e940c4775 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs @@ -60,7 +60,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -168,6 +168,38 @@ instance ( DSIGNMAlgorithm d forgetSignKeyKESWith allocator (SignKeyCompactSingleKES v) = forgetSignKeyDSIGNMWith allocator v +instance ( KESAlgorithm (CompactSingleKES d) + , UnsoundDSIGNMAlgorithm d + ) + => UnsoundPureKESAlgorithm (CompactSingleKES d) where + data UnsoundPureSignKeyKES (CompactSingleKES d) = + UnsoundPureSignKeyCompactSingleKES (SignKeyDSIGN d) + deriving (Generic) + + unsoundPureSignKES ctxt t a (UnsoundPureSignKeyCompactSingleKES sk) = + assert (t == 0) $! + SigCompactSingleKES (signDSIGN ctxt a sk) (deriveVerKeyDSIGN sk) + + unsoundPureUpdateKES _ctx _sk _to = Nothing + + -- + -- Key generation + -- + + unsoundPureGenKeyKES seed = + UnsoundPureSignKeyCompactSingleKES $! genKeyDSIGN seed + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeyCompactSingleKES v) = + VerKeyCompactSingleKES $! deriveVerKeyDSIGN v + + unsoundPureSignKeyKESToSoundSignKeyKES = + unsoundPureSignKeyKESToSoundSignKeyKESViaSer + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeyCompactSingleKES sk) = + rawSerialiseSignKeyDSIGN sk + rawDeserialiseUnsoundPureSignKeyKES b = + UnsoundPureSignKeyCompactSingleKES <$> rawDeserialiseSignKeyDSIGN b + instance ( KESAlgorithm (CompactSingleKES d) , DSIGNMAlgorithm d ) => OptimizedKESAlgorithm (CompactSingleKES d) where @@ -227,3 +259,35 @@ instance (DSIGNMAlgorithm d, KnownNat (SizeSigKES (CompactSingleKES d))) => From slice :: Word -> Word -> ByteString -> ByteString slice offset size = BS.take (fromIntegral size) . BS.drop (fromIntegral offset) + +-- +-- UnsoundPureSignKey instances +-- + +deriving instance DSIGNAlgorithm d => Show (UnsoundPureSignKeyKES (CompactSingleKES d)) +deriving instance (DSIGNAlgorithm d, Eq (SignKeyDSIGN d)) => Eq (UnsoundPureSignKeyKES (CompactSingleKES d)) + +instance (UnsoundDSIGNMAlgorithm d, KnownNat (SizeSigDSIGN d + SizeVerKeyDSIGN d)) => ToCBOR (UnsoundPureSignKeyKES (CompactSingleKES d)) where + toCBOR = encodeUnsoundPureSignKeyKES + encodedSizeExpr _size _skProxy = encodedSignKeyKESSizeExpr (Proxy :: Proxy (SignKeyKES (CompactSingleKES d))) + +instance (UnsoundDSIGNMAlgorithm d, KnownNat (SizeSigDSIGN d + SizeVerKeyDSIGN d)) => FromCBOR (UnsoundPureSignKeyKES (CompactSingleKES d)) where + fromCBOR = decodeUnsoundPureSignKeyKES + +instance DSIGNAlgorithm d => NoThunks (UnsoundPureSignKeyKES (CompactSingleKES d)) + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (CompactSingleKES d)) where + directSerialise push (SignKeyCompactSingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (CompactSingleKES d)) where + directDeserialise pull = SignKeyCompactSingleKES <$!> directDeserialise pull + +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (CompactSingleKES d)) where + directSerialise push (VerKeyCompactSingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (CompactSingleKES d)) where + directDeserialise pull = VerKeyCompactSingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index ce37acbe8..650267641 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -6,6 +6,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} @@ -86,7 +87,8 @@ module Cardano.Crypto.KES.CompactSum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS -import Control.Monad (guard) +import qualified Data.ByteString.Internal as BS +import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) @@ -95,12 +97,17 @@ import Cardano.Crypto.Hash.Class import Cardano.Crypto.KES.Class import Cardano.Crypto.KES.CompactSingle (CompactSingleKES) import Cardano.Crypto.Util +import Cardano.Crypto.Seed import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.Monad.Trans (lift) import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type CompactSum0KES d = CompactSingleKES d @@ -461,3 +468,174 @@ instance ( OptimizedKESAlgorithm d ) => FromCBOR (SigKES (CompactSumKES h d)) where fromCBOR = decodeSigKES + + +-- +-- Unsound pure KES API +-- +instance ( KESAlgorithm (CompactSumKES h d) + , HashAlgorithm h + , UnsoundPureKESAlgorithm d + ) + => UnsoundPureKESAlgorithm (CompactSumKES h d) where + data UnsoundPureSignKeyKES (CompactSumKES h d) = + UnsoundPureSignKeyCompactSumKES !(UnsoundPureSignKeyKES d) + !Seed + !(VerKeyKES d) + !(VerKeyKES d) + deriving (Generic) + + unsoundPureSignKES ctxt t a (UnsoundPureSignKeyCompactSumKES sk _r_1 vk_0 vk_1) = + SigCompactSumKES sigma vk_other + where + (sigma, vk_other) + | t < _T = (unsoundPureSignKES ctxt t a sk, vk_1) + | otherwise = (unsoundPureSignKES ctxt (t - _T) a sk, vk_0) + + _T = totalPeriodsKES (Proxy :: Proxy d) + + unsoundPureUpdateKES ctx (UnsoundPureSignKeyCompactSumKES sk r_1 vk_0 vk_1) t + | t+1 < _T = do + sk' <- unsoundPureUpdateKES ctx sk t + let r_1' = r_1 + return $! UnsoundPureSignKeyCompactSumKES sk' r_1' vk_0 vk_1 + | t+1 == _T = do + let sk' = unsoundPureGenKeyKES r_1 + let r_1' = mkSeedFromBytes (BS.replicate (fromIntegral (seedSizeKES (Proxy @d))) 0) + return $! UnsoundPureSignKeyCompactSumKES sk' r_1' vk_0 vk_1 + | otherwise = do + sk' <- unsoundPureUpdateKES ctx sk (t - _T) + let r_1' = r_1 + return $! UnsoundPureSignKeyCompactSumKES sk' r_1' vk_0 vk_1 + where + _T = totalPeriodsKES (Proxy :: Proxy d) + + -- + -- Key generation + -- + + unsoundPureGenKeyKES r = + let r0 = mkSeedFromBytes $ digest (Proxy @h) (BS.cons 1 $ getSeedBytes r) + r1 = mkSeedFromBytes $ digest (Proxy @h) (BS.cons 2 $ getSeedBytes r) + sk_0 = unsoundPureGenKeyKES r0 + vk_0 = unsoundPureDeriveVerKeyKES sk_0 + sk_1 = unsoundPureGenKeyKES r1 + vk_1 = unsoundPureDeriveVerKeyKES sk_1 + in UnsoundPureSignKeyCompactSumKES sk_0 r1 vk_0 vk_1 + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeyCompactSumKES _ _ vk_0 vk_1) = + VerKeyCompactSumKES (hashPairOfVKeys (vk_0, vk_1)) + + unsoundPureSignKeyKESToSoundSignKeyKES (UnsoundPureSignKeyCompactSumKES sk r_1 vk_0 vk_1) = + SignKeyCompactSumKES + <$> unsoundPureSignKeyKESToSoundSignKeyKES sk + <*> (fmap MLockedSeed . mlsbFromByteString . getSeedBytes $ r_1) + <*> pure vk_0 + <*> pure vk_1 + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeyCompactSumKES sk r_1 vk_0 vk_1) = + let ssk = rawSerialiseUnsoundPureSignKeyKES sk + sr1 = getSeedBytes r_1 + in mconcat + [ ssk + , sr1 + , rawSerialiseVerKeyKES vk_0 + , rawSerialiseVerKeyKES vk_1 + ] + + rawDeserialiseUnsoundPureSignKeyKES b = do + guard (BS.length b == fromIntegral size_total) + sk <- rawDeserialiseUnsoundPureSignKeyKES b_sk + let r = mkSeedFromBytes b_r + vk_0 <- rawDeserialiseVerKeyKES b_vk0 + vk_1 <- rawDeserialiseVerKeyKES b_vk1 + return (UnsoundPureSignKeyCompactSumKES sk r vk_0 vk_1) + where + b_sk = slice off_sk size_sk b + b_r = slice off_r size_r b + b_vk0 = slice off_vk0 size_vk b + b_vk1 = slice off_vk1 size_vk b + + size_sk = sizeSignKeyKES (Proxy :: Proxy d) + size_r = seedSizeKES (Proxy :: Proxy d) + size_vk = sizeVerKeyKES (Proxy :: Proxy d) + size_total = sizeSignKeyKES (Proxy :: Proxy (CompactSumKES h d)) + + off_sk = 0 :: Word + off_r = size_sk + off_vk0 = off_r + size_r + off_vk1 = off_vk0 + size_vk + +-- +-- UnsoundPureSignKey instances +-- + +deriving instance (KESAlgorithm d, Show (UnsoundPureSignKeyKES d)) => Show (UnsoundPureSignKeyKES (CompactSumKES h d)) +deriving instance (KESAlgorithm d, Eq (UnsoundPureSignKeyKES d)) => Eq (UnsoundPureSignKeyKES (CompactSumKES h d)) + +instance ( SizeHash h ~ SeedSizeKES d + , OptimizedKESAlgorithm d + , UnsoundPureKESAlgorithm d + , SodiumHashAlgorithm h + , KnownNat (SizeVerKeyKES (CompactSumKES h d)) + , KnownNat (SizeSignKeyKES (CompactSumKES h d)) + , KnownNat (SizeSigKES (CompactSumKES h d)) + ) => ToCBOR (UnsoundPureSignKeyKES (CompactSumKES h d)) where + toCBOR = encodeUnsoundPureSignKeyKES + encodedSizeExpr _size _skProxy = encodedSignKeyKESSizeExpr (Proxy :: Proxy (SignKeyKES (CompactSumKES h d))) + +instance ( SizeHash h ~ SeedSizeKES d + , OptimizedKESAlgorithm d + , UnsoundPureKESAlgorithm d + , SodiumHashAlgorithm h + , KnownNat (SizeVerKeyKES (CompactSumKES h d)) + , KnownNat (SizeSignKeyKES (CompactSumKES h d)) + , KnownNat (SizeSigKES (CompactSumKES h d)) + ) => FromCBOR (UnsoundPureSignKeyKES (CompactSumKES h d)) where + fromCBOR = decodeUnsoundPureSignKeyKES + +instance (NoThunks (UnsoundPureSignKeyKES d), KESAlgorithm d) => NoThunks (UnsoundPureSignKeyKES (CompactSumKES h d)) + + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectSerialise (SignKeyKES (CompactSumKES h d)) where + directSerialise push (SignKeyCompactSumKES sk r vk0 vk1) = do + directSerialise push sk + directSerialise push r + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectDeserialise (SignKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + r <- directDeserialise pull + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeyCompactSumKES sk r vk0 vk1 + + +instance DirectSerialise (VerKeyKES (CompactSumKES h d)) where + directSerialise push (VerKeyCompactSumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (CompactSumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "Invalid hash") return $! VerKeyCompactSumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 4e2a91516..26bd61937 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -15,6 +15,7 @@ module Cardano.Crypto.KES.Mock ( MockKES , VerKeyKES (..) , SignKeyKES (..) + , UnsoundPureSignKeyKES (..) , SigKES (..) ) where @@ -24,6 +25,8 @@ import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import GHC.TypeNats (Nat, KnownNat, natVal) import NoThunks.Class (NoThunks) +import qualified Data.ByteString.Internal as BS +import Foreign.Ptr (castPtr) import Control.Exception (assert) @@ -35,8 +38,15 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Util import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium - ( mlsbAsByteString + ( mlsbToByteString ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , ForeignPtr (..) + , mallocForeignPtrBytes + , withForeignPtr + ) +import Cardano.Crypto.DirectSerialise data MockKES (t :: Nat) @@ -151,11 +161,56 @@ instance KnownNat t => KESAlgorithm (MockKES t) where -- genKeyKESWith _allocator seed = do - let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes . mlsbAsByteString . mlockedSeedMLSB $ seed) getRandomWord64) + seedBS <- mlsbToByteString $ mlockedSeedMLSB seed + let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes seedBS) getRandomWord64) return $! SignKeyMockKES vk 0 forgetSignKeyKESWith _ = const $ return () +instance KnownNat t => UnsoundPureKESAlgorithm (MockKES t) where + -- + -- Key and signature types + -- + + data UnsoundPureSignKeyKES (MockKES t) = + UnsoundPureSignKeyMockKES !(VerKeyKES (MockKES t)) !Period + deriving stock (Show, Eq, Generic) + deriving anyclass (NoThunks) + + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeyMockKES vk _) = vk + + unsoundPureUpdateKES () (UnsoundPureSignKeyMockKES vk t') t = + assert (t == t') $! + if t+1 < totalPeriodsKES (Proxy @(MockKES t)) + then Just $! UnsoundPureSignKeyMockKES vk (t+1) + else Nothing + + -- | Produce valid signature only with correct key, i.e., same iteration and + -- allowed KES period. + unsoundPureSignKES () t a (UnsoundPureSignKeyMockKES vk t') = + assert (t == t') $! + SigMockKES (castHash (hashWith getSignableRepresentation a)) + (SignKeyMockKES vk t) + + -- + -- Key generation + -- + + unsoundPureGenKeyKES seed = + let vk = VerKeyMockKES (runMonadRandomWithSeed seed getRandomWord64) + in UnsoundPureSignKeyMockKES vk 0 + + unsoundPureSignKeyKESToSoundSignKeyKES (UnsoundPureSignKeyMockKES vk t) = + return $ SignKeyMockKES vk t + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeyMockKES vk t) = + rawSerialiseSignKeyMockKES (SignKeyMockKES vk t) + + rawDeserialiseUnsoundPureSignKeyKES bs = do + SignKeyMockKES vt t <- rawDeserialiseSignKeyMockKES bs + return $ UnsoundPureSignKeyMockKES vt t + instance KnownNat t => UnsoundKESAlgorithm (MockKES t) where rawSerialiseSignKeyKES sk = return $ rawSerialiseSignKeyMockKES sk @@ -194,3 +249,40 @@ instance KnownNat t => ToCBOR (SigKES (MockKES t)) where instance KnownNat t => FromCBOR (SigKES (MockKES t)) where fromCBOR = decodeSigKES + +instance KnownNat t => ToCBOR (UnsoundPureSignKeyKES (MockKES t)) where + toCBOR = encodeUnsoundPureSignKeyKES + encodedSizeExpr _size _skProxy = encodedSignKeyKESSizeExpr (Proxy :: Proxy (SignKeyKES (MockKES t))) + +instance KnownNat t => FromCBOR (UnsoundPureSignKeyKES (MockKES t)) where + fromCBOR = decodeUnsoundPureSignKeyKES + +instance (KnownNat t) => DirectSerialise (SignKeyKES (MockKES t)) where + directSerialise put sk = do + let bs = rawSerialiseSignKeyMockKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> put cstr (fromIntegral len) + +instance (KnownNat t) => DirectDeserialise (SignKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeSignKeyKES (Proxy @(MockKES t)) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "directDeserialise @(SignKeyKES (MockKES t))") return $ + rawDeserialiseSignKeyMockKES bs + +instance (KnownNat t) => DirectSerialise (VerKeyKES (MockKES t)) where + directSerialise push sk = do + let bs = rawSerialiseVerKeyKES sk + unpackByteStringCStringLen bs $ \(cstr, len) -> push cstr (fromIntegral len) + +instance (KnownNat t) => DirectDeserialise (VerKeyKES (MockKES t)) where + directDeserialise pull = do + let len = fromIntegral $ sizeVerKeyKES (Proxy @(MockKES t)) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> + pull (castPtr ptr) (fromIntegral len) + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "directDeserialise @(VerKeyKES (MockKES t))") return $ + rawDeserialiseVerKeyKES bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs index 202b397b5..43ab9561e 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs @@ -65,3 +65,15 @@ instance KESAlgorithm NeverKES where instance UnsoundKESAlgorithm NeverKES where rawSerialiseSignKeyKES _ = return mempty rawDeserialiseSignKeyKESWith _ _ = return $ Just NeverUsedSignKeyKES + +instance UnsoundPureKESAlgorithm NeverKES where + data UnsoundPureSignKeyKES NeverKES = NeverUsedUnsoundPureSignKeyKES + deriving (Show, Eq, Generic, NoThunks) + + unsoundPureSignKES = error "KES not available" + unsoundPureGenKeyKES _ = NeverUsedUnsoundPureSignKeyKES + unsoundPureDeriveVerKeyKES _ = NeverUsedVerKeyKES + unsoundPureUpdateKES _ = error "KES not available" + unsoundPureSignKeyKESToSoundSignKeyKES _ = return NeverUsedSignKeyKES + rawSerialiseUnsoundPureSignKeyKES _ = mempty + rawDeserialiseUnsoundPureSignKeyKES _ = Just NeverUsedUnsoundPureSignKeyKES diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index b8bfe2186..88686febf 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -22,6 +22,7 @@ module Cardano.Crypto.KES.Simple ( SimpleKES , SigKES (..) , SignKeyKES (SignKeySimpleKES, ThunkySignKeySimpleKES) + , UnsoundPureSignKeyKES (UnsoundPureSignKeySimpleKES, UnsoundPureThunkySignKeySimpleKES) ) where @@ -43,7 +44,10 @@ import Cardano.Crypto.KES.Class import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium.MLockedBytes import Cardano.Crypto.Util +import Cardano.Crypto.Seed +import Cardano.Crypto.DirectSerialise import Data.Unit.Strict (forceElemsToWHNF) +import Data.Maybe (fromMaybe) data SimpleKES d (t :: Nat) @@ -68,6 +72,14 @@ pattern SignKeySimpleKES v <- ThunkySignKeySimpleKES v {-# COMPLETE SignKeySimpleKES #-} +-- | See 'VerKeySimpleKES'. +pattern UnsoundPureSignKeySimpleKES :: Vector (SignKeyDSIGN d) -> UnsoundPureSignKeyKES (SimpleKES d t) +pattern UnsoundPureSignKeySimpleKES v <- UnsoundPureThunkySignKeySimpleKES v + where + UnsoundPureSignKeySimpleKES v = UnsoundPureThunkySignKeySimpleKES (forceElemsToWHNF v) + +{-# COMPLETE UnsoundPureSignKeySimpleKES #-} + instance ( DSIGNMAlgorithm d , KnownNat t , KnownNat (SeedSizeDSIGN d * t) @@ -176,6 +188,63 @@ instance ( DSIGNMAlgorithm d forgetSignKeyKESWith allocator (SignKeySimpleKES sks) = Vec.mapM_ (forgetSignKeyDSIGNMWith allocator) sks +instance ( KESAlgorithm (SimpleKES d t) + , KnownNat t + , DSIGNAlgorithm d + , UnsoundDSIGNMAlgorithm d + ) + => UnsoundPureKESAlgorithm (SimpleKES d t) where + + newtype UnsoundPureSignKeyKES (SimpleKES d t) = + UnsoundPureThunkySignKeySimpleKES (Vector (SignKeyDSIGN d)) + deriving Generic + + unsoundPureGenKeyKES seed = + let seedSize = fromIntegral (seedSizeDSIGN (Proxy :: Proxy d)) + duration = fromIntegral (natVal (Proxy @t)) + seedChunk t = + mkSeedFromBytes (BS.take seedSize . BS.drop (seedSize * t) $ getSeedBytes seed) + in + UnsoundPureSignKeySimpleKES $ + Vec.generate duration (genKeyDSIGN . seedChunk) + + unsoundPureSignKES ctxt j a (UnsoundPureSignKeySimpleKES sks) = + case sks !? fromIntegral j of + Nothing -> error ("SimpleKES.unsoundPureSignKES: period out of range " ++ show j) + Just sk -> SigSimpleKES $! signDSIGN ctxt a sk + + unsoundPureUpdateKES _ (UnsoundPureThunkySignKeySimpleKES sk) t + | t+1 < fromIntegral (natVal (Proxy @t)) + = Just $! UnsoundPureThunkySignKeySimpleKES sk + | otherwise + = Nothing + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeySimpleKES sks) = + VerKeySimpleKES $! Vec.map deriveVerKeyDSIGN sks + + unsoundPureSignKeyKESToSoundSignKeyKES (UnsoundPureThunkySignKeySimpleKES sks) = do + SignKeySimpleKES <$> mapM convertSK sks + where + convertSK = fmap (fromMaybe (error "unsoundPureSignKeyKESToSoundSignKeyKES: deserialisation failed")) + . rawDeserialiseSignKeyDSIGNM + . rawSerialiseSignKeyDSIGN + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeySimpleKES sks) = + foldMap rawSerialiseSignKeyDSIGN sks + + + rawDeserialiseUnsoundPureSignKeyKES bs + | let duration = fromIntegral (natVal (Proxy :: Proxy t)) + sizeKey = fromIntegral (sizeSignKeyDSIGN (Proxy :: Proxy d)) + skbs = splitsAt (replicate duration sizeKey) bs + , length skbs == duration + = do + sks <- mapM rawDeserialiseSignKeyDSIGN skbs + return $! UnsoundPureSignKeySimpleKES (Vec.fromList sks) + + | otherwise + = Nothing + instance ( UnsoundDSIGNMAlgorithm d, KnownNat t, KESAlgorithm (SimpleKES d t)) @@ -202,13 +271,16 @@ instance ( UnsoundDSIGNMAlgorithm d, KnownNat t, KESAlgorithm (SimpleKES d t)) deriving instance DSIGNMAlgorithm d => Show (VerKeyKES (SimpleKES d t)) deriving instance (DSIGNMAlgorithm d, Show (SignKeyDSIGNM d)) => Show (SignKeyKES (SimpleKES d t)) +deriving instance (DSIGNMAlgorithm d, Show (SignKeyDSIGNM d)) => Show (UnsoundPureSignKeyKES (SimpleKES d t)) deriving instance DSIGNMAlgorithm d => Show (SigKES (SimpleKES d t)) -deriving instance DSIGNMAlgorithm d => Eq (VerKeyKES (SimpleKES d t)) -deriving instance DSIGNMAlgorithm d => Eq (SigKES (SimpleKES d t)) +deriving instance DSIGNMAlgorithm d => Eq (VerKeyKES (SimpleKES d t)) +deriving instance DSIGNMAlgorithm d => Eq (SigKES (SimpleKES d t)) +deriving instance Eq (SignKeyDSIGN d) => Eq (UnsoundPureSignKeyKES (SimpleKES d t)) instance DSIGNMAlgorithm d => NoThunks (SigKES (SimpleKES d t)) instance DSIGNMAlgorithm d => NoThunks (SignKeyKES (SimpleKES d t)) +instance DSIGNMAlgorithm d => NoThunks (UnsoundPureSignKeyKES (SimpleKES d t)) instance DSIGNMAlgorithm d => NoThunks (VerKeyKES (SimpleKES d t)) instance ( DSIGNMAlgorithm d @@ -249,3 +321,22 @@ instance (DSIGNMAlgorithm d => FromCBOR (SigKES (SimpleKES d t)) where fromCBOR = decodeSigKES +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SimpleKES d t)) where + directSerialise push (VerKeySimpleKES vks) = + mapM_ (directSerialise push) vks + +instance (DirectDeserialise (VerKeyDSIGN d), KnownNat t) => DirectDeserialise (VerKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + vks <- Vec.replicateM duration (directDeserialise pull) + return $! VerKeySimpleKES vks + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SimpleKES d t)) where + directSerialise push (SignKeySimpleKES sks) = + mapM_ (directSerialise push) sks + +instance (DirectDeserialise (SignKeyDSIGNM d), KnownNat t) => DirectDeserialise (SignKeyKES (SimpleKES d t)) where + directDeserialise pull = do + let duration = fromIntegral (natVal (Proxy :: Proxy t)) + sks <- Vec.replicateM duration (directDeserialise pull) + return $! SignKeySimpleKES sks diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs index 2d38fb527..606badaef 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs @@ -50,7 +50,7 @@ import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.Hash.Class import Cardano.Crypto.DSIGN.Class as DSIGN import Cardano.Crypto.KES.Class - +import Cardano.Crypto.DirectSerialise -- | A standard signature scheme is a forward-secure signature scheme with a -- single time period. @@ -140,6 +140,37 @@ instance (DSIGNMAlgorithm d) => KESAlgorithm (SingleKES d) where forgetSignKeyKESWith allocator (SignKeySingleKES v) = forgetSignKeyDSIGNMWith allocator v +instance ( KESAlgorithm (SingleKES d) + , UnsoundDSIGNMAlgorithm d + ) + => UnsoundPureKESAlgorithm (SingleKES d) where + newtype UnsoundPureSignKeyKES (SingleKES d) = UnsoundPureSignKeySingleKES (SignKeyDSIGN d) + deriving (Generic) + + unsoundPureSignKES ctxt t a (UnsoundPureSignKeySingleKES sk) = + assert (t == 0) $! + SigSingleKES $! signDSIGN ctxt a sk + + unsoundPureUpdateKES _ctx _sk _to = Nothing + + -- + -- Key generation + -- + + unsoundPureGenKeyKES seed = + UnsoundPureSignKeySingleKES $! genKeyDSIGN seed + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeySingleKES v) = + VerKeySingleKES $! deriveVerKeyDSIGN v + + unsoundPureSignKeyKESToSoundSignKeyKES = + unsoundPureSignKeyKESToSoundSignKeyKESViaSer + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeySingleKES sk) = + rawSerialiseSignKeyDSIGN sk + rawDeserialiseUnsoundPureSignKeyKES b = + UnsoundPureSignKeySingleKES <$> rawDeserialiseSignKeyDSIGN b + instance (KESAlgorithm (SingleKES d), UnsoundDSIGNMAlgorithm d) => UnsoundKESAlgorithm (SingleKES d) where rawSerialiseSignKeyKES (SignKeySingleKES sk) = @@ -187,4 +218,35 @@ instance DSIGNMAlgorithm d => ToCBOR (SigKES (SingleKES d)) where instance DSIGNMAlgorithm d => FromCBOR (SigKES (SingleKES d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + +-- +-- UnsoundPureSignKey instances +-- + +deriving instance DSIGNAlgorithm d => Show (UnsoundPureSignKeyKES (SingleKES d)) +deriving instance Eq (SignKeyDSIGN d) => Eq (UnsoundPureSignKeyKES (SingleKES d)) + +instance (UnsoundDSIGNMAlgorithm d) => ToCBOR (UnsoundPureSignKeyKES (SingleKES d)) where + toCBOR = encodeUnsoundPureSignKeyKES + encodedSizeExpr _size _skProxy = encodedSignKeyKESSizeExpr (Proxy :: Proxy (SignKeyKES (SingleKES d))) + +instance (UnsoundDSIGNMAlgorithm d) => FromCBOR (UnsoundPureSignKeyKES (SingleKES d)) where + fromCBOR = decodeUnsoundPureSignKeyKES + +instance DSIGNAlgorithm d => NoThunks (UnsoundPureSignKeyKES (SingleKES d)) + +-- +-- Direct ser/deser +-- + +instance (DirectSerialise (SignKeyDSIGNM d)) => DirectSerialise (SignKeyKES (SingleKES d)) where + directSerialise push (SignKeySingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (SignKeyDSIGNM d)) => DirectDeserialise (SignKeyKES (SingleKES d)) where + directDeserialise pull = SignKeySingleKES <$!> directDeserialise pull + +instance (DirectSerialise (VerKeyDSIGN d)) => DirectSerialise (VerKeyKES (SingleKES d)) where + directSerialise push (VerKeySingleKES sk) = directSerialise push sk + +instance (DirectDeserialise (VerKeyDSIGN d)) => DirectDeserialise (VerKeyKES (SingleKES d)) where + directDeserialise pull = VerKeySingleKES <$!> directDeserialise pull diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index 300962d11..f8fb26a40 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -54,6 +54,7 @@ module Cardano.Crypto.KES.Sum ( import Data.Proxy (Proxy(..)) import GHC.Generics (Generic) import qualified Data.ByteString as BS +import qualified Data.ByteString.Internal as BS import Control.Monad (guard, (<$!>)) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -63,12 +64,16 @@ import Cardano.Crypto.Hash.Class import Cardano.Crypto.KES.Class import Cardano.Crypto.KES.Single (SingleKES) import Cardano.Crypto.Util +import Cardano.Crypto.Seed import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.DirectSerialise + import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) - +import Foreign.Ptr (castPtr) -- | A 2^0 period KES type Sum0KES d = SingleKES d @@ -383,4 +388,173 @@ instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSiz instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => FromCBOR (SigKES (SumKES h d)) where fromCBOR = decodeSigKES - {-# INLINE fromCBOR #-} + +-- +-- Unsound pure KES API +-- +instance ( KESAlgorithm (SumKES h d) + , HashAlgorithm h + , UnsoundPureKESAlgorithm d + ) + => UnsoundPureKESAlgorithm (SumKES h d) where + data UnsoundPureSignKeyKES (SumKES h d) = + UnsoundPureSignKeySumKES !(UnsoundPureSignKeyKES d) + !Seed + !(VerKeyKES d) + !(VerKeyKES d) + deriving (Generic) + + unsoundPureSignKES ctxt t a (UnsoundPureSignKeySumKES sk _r_1 vk_0 vk_1) = + SigSumKES sigma vk_0 vk_1 + where + sigma + | t < _T = unsoundPureSignKES ctxt t a sk + | otherwise = unsoundPureSignKES ctxt (t - _T) a sk + + _T = totalPeriodsKES (Proxy :: Proxy d) + + unsoundPureUpdateKES ctx (UnsoundPureSignKeySumKES sk r_1 vk_0 vk_1) t + | t+1 < _T = do + sk' <- unsoundPureUpdateKES ctx sk t + return $! UnsoundPureSignKeySumKES sk' r_1 vk_0 vk_1 + | t+1 == _T = do + let sk' = unsoundPureGenKeyKES r_1 + let r_1' = mkSeedFromBytes (BS.replicate (fromIntegral (seedSizeKES (Proxy @d))) 0) + return $! UnsoundPureSignKeySumKES sk' r_1' vk_0 vk_1 + | otherwise = do + sk' <- unsoundPureUpdateKES ctx sk (t - _T) + return $! UnsoundPureSignKeySumKES sk' r_1 vk_0 vk_1 + where + _T = totalPeriodsKES (Proxy :: Proxy d) + + -- + -- Key generation + -- + + unsoundPureGenKeyKES r = + let (r0, r1) = expandSeed (Proxy @h) r + sk_0 = unsoundPureGenKeyKES r0 + vk_0 = unsoundPureDeriveVerKeyKES sk_0 + sk_1 = unsoundPureGenKeyKES r1 + vk_1 = unsoundPureDeriveVerKeyKES sk_1 + in UnsoundPureSignKeySumKES sk_0 r1 vk_0 vk_1 + + unsoundPureDeriveVerKeyKES (UnsoundPureSignKeySumKES _ _ vk_0 vk_1) = + VerKeySumKES (hashPairOfVKeys (vk_0, vk_1)) + + unsoundPureSignKeyKESToSoundSignKeyKES (UnsoundPureSignKeySumKES sk r_1 vk_0 vk_1) = + SignKeySumKES + <$> unsoundPureSignKeyKESToSoundSignKeyKES sk + <*> (fmap MLockedSeed . mlsbFromByteString . getSeedBytes $ r_1) + <*> pure vk_0 + <*> pure vk_1 + + rawSerialiseUnsoundPureSignKeyKES (UnsoundPureSignKeySumKES sk r_1 vk_0 vk_1) = + let ssk = rawSerialiseUnsoundPureSignKeyKES sk + sr1 = getSeedBytes r_1 + in mconcat + [ ssk + , sr1 + , rawSerialiseVerKeyKES vk_0 + , rawSerialiseVerKeyKES vk_1 + ] + + rawDeserialiseUnsoundPureSignKeyKES b = do + guard (BS.length b == fromIntegral size_total) + sk <- rawDeserialiseUnsoundPureSignKeyKES b_sk + let r = mkSeedFromBytes b_r + vk_0 <- rawDeserialiseVerKeyKES b_vk0 + vk_1 <- rawDeserialiseVerKeyKES b_vk1 + return (UnsoundPureSignKeySumKES sk r vk_0 vk_1) + where + b_sk = slice off_sk size_sk b + b_r = slice off_r size_r b + b_vk0 = slice off_vk0 size_vk b + b_vk1 = slice off_vk1 size_vk b + + size_sk = sizeSignKeyKES (Proxy :: Proxy d) + size_r = seedSizeKES (Proxy :: Proxy d) + size_vk = sizeVerKeyKES (Proxy :: Proxy d) + size_total = sizeSignKeyKES (Proxy :: Proxy (SumKES h d)) + + off_sk = 0 :: Word + off_r = size_sk + off_vk0 = off_r + size_r + off_vk1 = off_vk0 + size_vk + + +-- +-- UnsoundPureSignKey instances +-- + +deriving instance (KESAlgorithm d, Show (UnsoundPureSignKeyKES d)) => Show (UnsoundPureSignKeyKES (SumKES h d)) +deriving instance (KESAlgorithm d, Eq (UnsoundPureSignKeyKES d)) => Eq (UnsoundPureSignKeyKES (SumKES h d)) + +instance ( SizeHash h ~ SeedSizeKES d + , UnsoundPureKESAlgorithm d + , SodiumHashAlgorithm h + , KnownNat (SizeVerKeyKES (SumKES h d)) + , KnownNat (SizeSignKeyKES (SumKES h d)) + , KnownNat (SizeSigKES (SumKES h d)) + ) => ToCBOR (UnsoundPureSignKeyKES (SumKES h d)) where + toCBOR = encodeUnsoundPureSignKeyKES + encodedSizeExpr _size _skProxy = encodedSignKeyKESSizeExpr (Proxy :: Proxy (SignKeyKES (SumKES h d))) + +instance ( SizeHash h ~ SeedSizeKES d + , UnsoundPureKESAlgorithm d + , SodiumHashAlgorithm h + , KnownNat (SizeVerKeyKES (SumKES h d)) + , KnownNat (SizeSignKeyKES (SumKES h d)) + , KnownNat (SizeSigKES (SumKES h d)) + ) => FromCBOR (UnsoundPureSignKeyKES (SumKES h d)) where + fromCBOR = decodeUnsoundPureSignKeyKES + +instance (NoThunks (UnsoundPureSignKeyKES d), KESAlgorithm d) => NoThunks (UnsoundPureSignKeyKES (SumKES h d)) + +-- +-- Direct ser/deser +-- + +instance ( DirectSerialise (SignKeyKES d) + , DirectSerialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectSerialise (SignKeyKES (SumKES h d)) where + directSerialise push (SignKeySumKES sk r vk0 vk1) = do + directSerialise push sk + mlockedSeedUseAsCPtr r $ \ptr -> + push (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + directSerialise push vk0 + directSerialise push vk1 + +instance ( DirectDeserialise (SignKeyKES d) + , DirectDeserialise (VerKeyKES d) + , KESAlgorithm d + ) => DirectDeserialise (SignKeyKES (SumKES h d)) where + directDeserialise pull = do + sk <- directDeserialise pull + + r <- mlockedSeedNew + mlockedSeedUseAsCPtr r $ \ptr -> + pull (castPtr ptr) (fromIntegral $ seedSizeKES (Proxy :: Proxy d)) + + vk0 <- directDeserialise pull + vk1 <- directDeserialise pull + + return $! SignKeySumKES sk r vk0 vk1 + + +instance DirectSerialise (VerKeyKES (SumKES h d)) where + directSerialise push (VerKeySumKES h) = + unpackByteStringCStringLen (hashToBytes h) $ \(ptr, len) -> + push (castPtr ptr) (fromIntegral len) + +instance (HashAlgorithm h) + => DirectDeserialise (VerKeyKES (SumKES h d)) where + directDeserialise pull = do + let len :: Num a => a + len = fromIntegral $ sizeHash (Proxy @h) + fptr <- mallocForeignPtrBytes len + withForeignPtr fptr $ \ptr -> do + pull (castPtr ptr) len + let bs = BS.fromForeignPtr (unsafeRawForeignPtr fptr) 0 len + maybe (error "Invalid hash") return $! VerKeySumKES <$!> hashFromBytes bs diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs index 1ce55c66b..3fe9623ec 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/C.hs @@ -29,6 +29,8 @@ module Cardano.Crypto.Libsodium.C ( c_crypto_sign_ed25519_detached, c_crypto_sign_ed25519_verify_detached, c_crypto_sign_ed25519_sk_to_pk, + -- * RNG + c_sodium_randombytes_buf, -- * Helpers c_sodium_compare, -- * Constants @@ -182,3 +184,6 @@ foreign import capi unsafe "sodium.h crypto_sign_ed25519_sk_to_pk" c_crypto_sign -- -- foreign import capi unsafe "sodium.h sodium_compare" c_sodium_compare :: Ptr a -> Ptr a -> CSize -> IO Int + +-- | @void randombytes_buf(void * const buf, const size_t size);@ +foreign import capi unsafe "sodium/randombytes.h randombytes_buf" c_sodium_randombytes_buf :: Ptr a -> CSize -> IO () diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs index 5fb8c600d..7dab9880f 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs @@ -2,11 +2,13 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} module Cardano.Crypto.Libsodium.MLockedSeed where +import Cardano.Crypto.DirectSerialise import Cardano.Crypto.Libsodium.MLockedBytes ( MLockedSizedBytes, mlsbCopyWith, @@ -20,12 +22,16 @@ import Cardano.Crypto.Libsodium.Memory ( MLockedAllocator, mlockedMalloc, ) +import Cardano.Crypto.Libsodium.C ( + c_sodium_randombytes_buf, + ) import Cardano.Foreign (SizedPtr) import Control.DeepSeq (NFData) import Control.Monad.Class.MonadST (MonadST) +import Data.Proxy (Proxy (..)) import Data.Word (Word8) -import Foreign.Ptr (Ptr) -import GHC.TypeNats (KnownNat) +import Foreign.Ptr (Ptr, castPtr) +import GHC.TypeNats (KnownNat, natVal) import NoThunks.Class (NoThunks) -- | A seed of size @n@, stored in mlocked memory. This is required to prevent @@ -34,6 +40,18 @@ import NoThunks.Class (NoThunks) newtype MLockedSeed n = MLockedSeed {mlockedSeedMLSB :: MLockedSizedBytes n} deriving (NFData, NoThunks) +instance KnownNat n => DirectSerialise (MLockedSeed n) where + directSerialise push seed = + mlockedSeedUseAsCPtr seed $ \ptr -> + push (castPtr ptr) (fromIntegral $ natVal seed) + +instance KnownNat n => DirectDeserialise (MLockedSeed n) where + directDeserialise pull = do + seed <- mlockedSeedNew + mlockedSeedUseAsCPtr seed $ \ptr -> + pull (castPtr ptr) (fromIntegral $ natVal seed) + return seed + withMLockedSeedAsMLSB :: Functor m => (MLockedSizedBytes n -> m (MLockedSizedBytes n)) @@ -66,6 +84,18 @@ mlockedSeedNewZeroWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (ML mlockedSeedNewZeroWith allocator = MLockedSeed <$> mlsbNewZeroWith allocator +mlockedSeedNewRandom :: forall n. (KnownNat n) => IO (MLockedSeed n) +mlockedSeedNewRandom = mlockedSeedNewRandomWith mlockedMalloc + +mlockedSeedNewRandomWith :: forall n. (KnownNat n) => MLockedAllocator IO -> IO (MLockedSeed n) +mlockedSeedNewRandomWith allocator = do + mls <- MLockedSeed <$> mlsbNewZeroWith allocator + mlockedSeedUseAsCPtr mls $ \dst -> do + c_sodium_randombytes_buf dst size + return mls + where + size = fromIntegral $ natVal (Proxy @n) + mlockedSeedFinalize :: (MonadST m) => MLockedSeed n -> m () mlockedSeedFinalize = mlsbFinalize . mlockedSeedMLSB diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index a4405ef5d..cd927cb42 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -25,7 +25,13 @@ module Cardano.Crypto.Libsodium.Memory ( copyMem, allocaBytes, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), + mallocForeignPtrBytes, + withForeignPtr, + -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, ) where diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 68b57392f..345a18342 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -28,12 +28,18 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( mlockedAllocForeignPtrWith, mlockedAllocForeignPtrBytesWith, + -- * 'ForeignPtr' operations, generalized to 'MonadST' + ForeignPtr (..), + mallocForeignPtrBytes, + withForeignPtr, + -- * Unmanaged memory, generalized to 'MonadST' zeroMem, copyMem, allocaBytes, -- * ByteString memory access, generalized to 'MonadST' + unpackByteStringCStringLen, packByteStringCStringLen, -- * Helper @@ -43,22 +49,24 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( import Control.DeepSeq (NFData (..), rwhnf) import Control.Exception (Exception, mask_) import Control.Monad (when, void) -import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadST (MonadST, stToIO) import Control.Monad.Class.MonadThrow (MonadThrow (bracket)) import Control.Monad.ST (RealWorld, ST) -import Control.Monad.ST.Unsafe (unsafeIOToST, unsafeSTToIO) +import Control.Monad.Primitive (touch) +import Control.Monad.ST.Unsafe (unsafeIOToST) import Data.ByteString (ByteString) import qualified Data.ByteString as BS +import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) import Foreign.C.Types (CSize (..)) -import Foreign.Concurrent (newForeignPtr) -import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, touchForeignPtr) +import qualified Foreign.Concurrent as Foreign +import qualified Foreign.ForeignPtr as Foreign hiding (newForeignPtr) +import qualified Foreign.ForeignPtr.Unsafe as Foreign import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) -import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) import Foreign.Storable (Storable (peek), sizeOf, alignment) @@ -66,13 +74,14 @@ import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import System.IO.Unsafe (unsafePerformIO) +import Data.Kind import Cardano.Crypto.Libsodium.C import Cardano.Foreign (c_memset, c_memcpy, SizedPtr (..)) import Cardano.Memory.Pool (initPool, grabNextBlock, Pool) -- | Foreign pointer to securely allocated memory. -newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: ForeignPtr a } +newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: Foreign.ForeignPtr a } deriving NoThunks via OnlyCheckWhnfNamed "MLockedForeignPtr" (MLockedForeignPtr a) instance NFData (MLockedForeignPtr a) where @@ -81,11 +90,11 @@ instance NFData (MLockedForeignPtr a) where withMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> (Ptr a -> m b) -> m b withMLockedForeignPtr (SFP fptr) f = do r <- f (unsafeForeignPtrToPtr fptr) - r <$ unsafeIOToMonadST (touchForeignPtr fptr) + r <$ unsafeIOToMonadST (Foreign.touchForeignPtr fptr) finalizeMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> m () finalizeMLockedForeignPtr (SFP fptr) = - unsafeIOToMonadST $ finalizeForeignPtr fptr + unsafeIOToMonadST $ Foreign.finalizeForeignPtr fptr {-# WARNING traceMLockedForeignPtr "Do not use traceMLockedForeignPtr in production" #-} @@ -103,7 +112,7 @@ makeMLockedPool = do (max 1 . fromIntegral $ 4096 `div` natVal (Proxy @n) `div` 64) (\size -> unsafeIOToST $ mask_ $ do ptr <- sodiumMalloc (fromIntegral size) - newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) + Foreign.newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) ) (\ptr -> do eraseMem (Proxy @n) ptr @@ -159,7 +168,7 @@ mlockedMallocIO size = SFP <$> do | otherwise -> do mask_ $ do ptr <- sodiumMalloc size - newForeignPtr ptr $ do + Foreign.newForeignPtr ptr $ do sodiumFree ptr size sodiumMalloc :: CSize -> IO (Ptr a) @@ -188,9 +197,38 @@ zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size -allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b -allocaBytes size f = - unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) +-- | A 'ForeignPtr' type, generalized to 'MonadST'. The type is tagged with +-- the correct Monad @m@ in order to ensure that foreign pointers created in +-- one ST context can only be used within the same ST context. +newtype ForeignPtr (m :: Type -> Type) a = ForeignPtr { unsafeRawForeignPtr :: Foreign.ForeignPtr a } + +mallocForeignPtrBytes :: (MonadST m) => Int -> m (ForeignPtr m a) +mallocForeignPtrBytes size = + ForeignPtr <$> unsafeIOToMonadST (Foreign.mallocForeignPtrBytes size) + +-- | 'Foreign.withForeignPtr', generalized to 'MonadST'. +-- Caveat: if the monadic action passed to 'withForeignPtr' does not terminate +-- (e.g., 'forever'), the 'ForeignPtr' finalizer may run prematurely. +withForeignPtr :: (MonadST m) => ForeignPtr m a -> (Ptr a -> m b) -> m b +withForeignPtr (ForeignPtr fptr) f = do + result <- f $ Foreign.unsafeForeignPtrToPtr fptr + stToIO $ touch fptr + return result + +allocaBytes :: (MonadThrow m, MonadST m) => Int -> (Ptr a -> m b) -> m b +allocaBytes size action = do + fptr <- mallocForeignPtrBytes size + withForeignPtr fptr action + +-- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' +-- function on it. +unpackByteStringCStringLen :: (MonadThrow m, MonadST m) => ByteString -> (CStringLen -> m a) -> m a +unpackByteStringCStringLen bs f = do + let len = BS.length bs + allocaBytes len $ \buf -> do + unsafeIOToMonadST $ BS.unsafeUseAsCString bs $ \ptr -> do + copyMem buf ptr (fromIntegral len) + f (buf, len) packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString packByteStringCStringLen = @@ -258,7 +296,6 @@ mlockedAllocaWith :: -> (Ptr a -> m b) -> m b mlockedAllocaWith allocator size = - bracket alloc free . flip withMLockedForeignPtr + bracket alloc finalizeMLockedForeignPtr . flip withMLockedForeignPtr where alloc = mlAllocate allocator size - free = finalizeMLockedForeignPtr diff --git a/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs b/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs index b8b13bb73..6f840c596 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs @@ -38,7 +38,7 @@ import Data.Kind (Type) import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Control.Monad.ST.Unsafe (unsafeIOToST) -import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadST (MonadST, stToIO) import Control.Monad.Primitive (primitive_, touch) import Data.Primitive.ByteArray ( ByteArray (..) diff --git a/cardano-crypto-tests/cardano-crypto-tests.cabal b/cardano-crypto-tests/cardano-crypto-tests.cabal index 0dd652e0d..27c809da7 100644 --- a/cardano-crypto-tests/cardano-crypto-tests.cabal +++ b/cardano-crypto-tests/cardano-crypto-tests.cabal @@ -74,7 +74,7 @@ library , contra-tracer ==0.1.0.1 , deepseq , formatting - , io-classes >= 1.1 + , io-classes >= 1.4.0 , mtl , nothunks , pretty-show diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index 1c81dcd46..102bdedf4 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -27,6 +27,7 @@ import Test.QuickCheck ( forAllShow, forAllShrinkShow, ioProperty, + counterexample, ) import Test.Tasty (TestTree, testGroup, adjustOption) import Test.Tasty.QuickCheck (testProperty, QuickCheckTests) @@ -94,6 +95,7 @@ import Cardano.Crypto.DSIGN ( ) import Cardano.Binary (FromCBOR, ToCBOR) import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes) +import Cardano.Crypto.DirectSerialise import Test.Crypto.Util ( Message, prop_raw_serialise, @@ -111,6 +113,9 @@ import Test.Crypto.Util ( showBadInputFor, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, + hexBS, ) import Cardano.Crypto.Libsodium.MLockedSeed @@ -362,6 +367,10 @@ testDSIGNMAlgorithm , FromCBOR (SigDSIGN v) , ContextDSIGN v ~ () , Signable v Message + , DirectSerialise (SignKeyDSIGNM v) + , DirectDeserialise (SignKeyDSIGNM v) + , DirectSerialise (VerKeyDSIGN v) + , DirectDeserialise (VerKeyDSIGN v) ) => Lock -> Proxy v @@ -451,6 +460,36 @@ testDSIGNMAlgorithm lock _ n = sig :: SigDSIGN v <- signDSIGNM () msg sk return $ prop_cbor_direct_vs_class encodeSigDSIGN sig ] + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + sk' <- directDeserialiseFromBS serialized + equals <- sk ==! sk' + forgetSignKeyDSIGNM sk' + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyDSIGN v <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + let raw = rawSerialiseVerKeyDSIGN vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + raw <- rawSerialiseSignKeyDSIGNM sk + return $ direct === raw + ] ] , testGroup "verify" @@ -477,6 +516,24 @@ testDSIGNMAlgorithm lock _ n = ioPropertyWithSK @v lock $ prop_no_thunks_IO . return , testProperty "Sig" $ \(msg :: Message) -> ioPropertyWithSK @v lock $ prop_no_thunks_IO . signDSIGNM () msg + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyDSIGN (Proxy @v)) sk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(SignKeyDSIGNM v) $! direct) + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk <- deriveVerKeyDSIGNM sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyDSIGN (Proxy @v)) vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyDSIGN v) $! direct) ] ] diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index abd5f9528..3d285313b 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -31,17 +31,20 @@ import Data.Set (Set) import qualified Data.Set as Set import Foreign.Ptr (WordPtr) import Data.IORef -import GHC.TypeNats (KnownNat) +import GHC.TypeNats (KnownNat, natVal) -import Control.Tracer +import Control.Monad (void) +import Control.Monad.Class.MonadST import Control.Monad.Class.MonadThrow import Control.Monad.IO.Class (liftIO) -import Control.Monad (void) +import Control.Tracer import Cardano.Crypto.DSIGN hiding (Signable) import Cardano.Crypto.Hash import Cardano.Crypto.KES +import Cardano.Crypto.DirectSerialise (DirectSerialise, DirectDeserialise) import Cardano.Crypto.Util (SignableRepresentation(..)) +import Cardano.Crypto.Seed (mkSeedFromBytes) import Cardano.Crypto.Libsodium import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.PinnedSizedBytes @@ -67,6 +70,8 @@ import Test.Crypto.Util ( noExceptionsThrown, Lock, withLock, + directSerialiseToBS, + directDeserialiseFromBS, ) import Test.Crypto.EqST import Test.Crypto.Instances (withMLockedSeedFromPSB) @@ -95,6 +100,10 @@ tests lock = , testKESAlgorithm @(CompactSum5KES Ed25519DSIGN Blake2b_256) lock "CompactSum5KES" ] +-------------------------------------------------------------------------------- +-- Show and Eq instances +-------------------------------------------------------------------------------- + -- We normally ensure that we avoid naively comparing signing keys by not -- providing instances, but for tests it is fine, so we provide the orphan -- instances here. @@ -137,6 +146,48 @@ instance ( EqST (SignKeyKES d) equalsM (SignKeyCompactSumKES s r v1 v2) (SignKeyCompactSumKES s' r' v1' v2') = (s, r, PureEqST v1, PureEqST v2) ==! (s', r', PureEqST v1', PureEqST v2') +-------------------------------------------------------------------------------- +-- Arbitrary instances +-------------------------------------------------------------------------------- + +genInitialSignKeyKES :: forall k. UnsoundPureKESAlgorithm k => Gen (UnsoundPureSignKeyKES k) +genInitialSignKeyKES = do + bytes <- BS.pack <$> vector (fromIntegral $ seedSizeKES (Proxy @k)) + let seed = mkSeedFromBytes bytes + return $ unsoundPureGenKeyKES seed + +instance (UnsoundPureKESAlgorithm k, Arbitrary (ContextKES k)) => Arbitrary (UnsoundPureSignKeyKES k) where + arbitrary = do + ctx <- arbitrary + let updateTo :: Period -> Period -> UnsoundPureSignKeyKES k -> Maybe (UnsoundPureSignKeyKES k) + updateTo target current sk + | target == current + = Just sk + | target > current + = updateTo target (succ current) =<< unsoundPureUpdateKES ctx sk current + | otherwise + = Nothing + period <- chooseBoundedIntegral (0, totalPeriodsKES (Proxy @k) - 1) + sk0 <- genInitialSignKeyKES + let skMay = updateTo period 0 sk0 + case skMay of + Just sk -> return sk + Nothing -> error "Attempted to generate SignKeyKES evolved beyond max period" + +instance (UnsoundPureKESAlgorithm k, Arbitrary (ContextKES k)) => Arbitrary (VerKeyKES k) where + arbitrary = unsoundPureDeriveVerKeyKES <$> arbitrary + +instance (UnsoundPureKESAlgorithm k, Signable k ByteString, Arbitrary (ContextKES k)) => Arbitrary (SigKES k) where + arbitrary = do + sk <- arbitrary + signable <- BS.pack <$> listOf arbitrary + ctx <- arbitrary + return $ unsoundPureSignKES ctx 0 signable sk + +-------------------------------------------------------------------------------- +-- Tests +-------------------------------------------------------------------------------- + testKESAlloc :: forall v. ( KESAlgorithm v @@ -193,11 +244,18 @@ testKESAlgorithm , FromCBOR (VerKeyKES v) , EqST (SignKeyKES v) -- only monadic EqST for signing keys , Show (SignKeyKES v) -- fake instance defined locally + , Eq (UnsoundPureSignKeyKES v) + , Show (UnsoundPureSignKeyKES v) , ToCBOR (SigKES v) , FromCBOR (SigKES v) , Signable v ~ SignableRepresentation , ContextKES v ~ () , UnsoundKESAlgorithm v + , UnsoundPureKESAlgorithm v + , DirectSerialise (SignKeyKES v) + , DirectSerialise (VerKeyKES v) + , DirectDeserialise (SignKeyKES v) + , DirectDeserialise (VerKeyKES v) ) => Lock -> String @@ -225,9 +283,32 @@ testKESAlgorithm lock n = , testProperty "Sig" $ \seedPSB (msg :: Message) -> ioProperty $ withLock lock $ fmap conjoin $ withAllUpdatesKES @v seedPSB $ \t sk -> do prop_no_thunks_IO (signKES () t msg sk) + + , testProperty "VerKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + prop_no_thunks_IO (return $! direct) + , testProperty "SignKey DirectSerialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + prop_no_thunks_IO (return $! direct) + , testProperty "VerKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) $! vk + prop_no_thunks_IO (directDeserialiseFromBS @IO @(VerKeyKES v) $! direct) + , testProperty "SignKey DirectDeserialise" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + bracket + (directDeserialiseFromBS @IO @(SignKeyKES v) $! direct) + forgetSignKeyKES + (prop_no_thunks_IO . return) ] , testProperty "same VerKey " $ prop_deriveVerKeyKES @v + , testProperty "no forgotten chunks in signkey" $ prop_noErasedBlocksInKey (Proxy @v) , testGroup "serialisation" [ testGroup "raw ser only" @@ -275,6 +356,9 @@ testKESAlgorithm lock n = ioPropertyWithSK @v lock $ \sk -> do sig :: SigKES v <- signKES () 0 msg sk return $ prop_cbor_with encodeSigKES decodeSigKES sig + , testProperty "UnsoundSignKeyKES" $ \seedPSB -> + let sk :: UnsoundPureSignKeyKES v = mkUnsoundPureSignKeyKES seedPSB + in prop_cbor_with encodeUnsoundPureSignKeyKES decodeUnsoundPureSignKeyKES sk ] , testGroup "To/FromCBOR class" @@ -313,6 +397,38 @@ testKESAlgorithm lock n = sig :: SigKES v <- signKES () 0 msg sk return $ prop_cbor_direct_vs_class encodeSigKES sig ] + + , testGroup "DirectSerialise" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + serialized <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + vk' <- directDeserialiseFromBS serialized + return $ vk === vk' + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + serialized <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + equals <- bracket + (directDeserialiseFromBS serialized) + forgetSignKeyKES + (sk ==!) + return $ + counterexample ("Serialized: " ++ hexBS serialized ++ " (length: " ++ show (BS.length serialized) ++ ")") $ + equals + ] + , testGroup "DirectSerialise matches raw" + [ testProperty "VerKey" $ + ioPropertyWithSK @v lock $ \sk -> do + vk :: VerKeyKES v <- deriveVerKeyKES sk + direct <- directSerialiseToBS (fromIntegral $ sizeVerKeyKES (Proxy @v)) vk + let raw = rawSerialiseVerKeyKES vk + return $ direct === raw + , testProperty "SignKey" $ + ioPropertyWithSK @v lock $ \sk -> do + direct <- directSerialiseToBS (fromIntegral $ sizeSignKeyKES (Proxy @v)) sk + raw <- rawSerialiseSignKeyKES sk + return $ direct === raw + ] ] , testGroup "verify" @@ -335,6 +451,12 @@ testKESAlgorithm lock n = -- [ testProperty "key overwritten after forget" $ prop_key_overwritten_after_forget (Proxy @v) -- ] + , testGroup "unsound pure" + [ testProperty "genKey" $ prop_unsoundPureGenKey @v Proxy + , testProperty "updateKES" $ prop_unsoundPureUpdateKES @v Proxy + , testProperty "deriveVerKey" $ prop_unsoundPureDeriveVerKey @v Proxy + , testProperty "sign" $ prop_unsoundPureSign @v Proxy + ] ] -- | Wrap an IO action that requires a 'SignKeyKES' into one that takes an @@ -349,6 +471,12 @@ withSK seedPSB = (withMLockedSeedFromPSB seedPSB genKeyKES) forgetSignKeyKES +mkUnsoundPureSignKeyKES :: UnsoundPureKESAlgorithm v + => PinnedSizedBytes (SeedSizeKES v) -> UnsoundPureSignKeyKES v +mkUnsoundPureSignKeyKES psb = + let seed = mkSeedFromBytes . psbToByteString $ psb + in unsoundPureGenKeyKES seed + -- | Wrap an IO action that requires a 'SignKeyKES' into a 'Property' that -- takes a non-mlocked seed (provided as a 'PinnedSizedBytes' of the -- appropriate size). The key, and the mlocked seed necessary to generate it, @@ -676,3 +804,113 @@ withAllUpdatesKES seedPSB f = withMLockedSeedFromPSB seedPSB $ \seed -> do xs <- go sk' (t + 1) return $ x:xs +withNullSeed :: forall m n a. (MonadThrow m, MonadST m, KnownNat n) => (MLockedSeed n -> m a) -> m a +withNullSeed = bracket + (MLockedSeed <$> mlsbFromByteString (BS.replicate (fromIntegral $ natVal (Proxy @n)) 0)) + mlockedSeedFinalize + +withNullSK :: forall m v a. (KESAlgorithm v, MonadThrow m, MonadST m) + => (SignKeyKES v -> m a) -> m a +withNullSK = bracket + (withNullSeed genKeyKES) + forgetSignKeyKES + + +-- | This test detects whether a sign key contains references to pool-allocated +-- blocks of memory that have been forgotten by the time the key is complete. +-- We do this based on the fact that the pooled allocator erases memory blocks +-- by overwriting them with series of 0xff bytes; thus we cut the serialized +-- key up into chunks of 16 bytes, and if any of those chunks is entirely +-- filled with 0xff bytes, we assume that we're looking at erased memory. +prop_noErasedBlocksInKey + :: forall v. + UnsoundKESAlgorithm v + => DirectSerialise (SignKeyKES v) + => Proxy v + -> Property +prop_noErasedBlocksInKey kesAlgorithm = + ioProperty . withNullSK @IO @v $ \sk -> do + let size :: Int = fromIntegral $ sizeSignKeyKES kesAlgorithm + serialized <- directSerialiseToBS size sk + forgetSignKeyKES sk + return $ counterexample (hexBS serialized) $ not (hasLongRunOfFF serialized) + +hasLongRunOfFF :: ByteString -> Bool +hasLongRunOfFF bs + | BS.length bs < 16 + = False + | otherwise + = let first16 = BS.take 16 bs + remainder = BS.drop 16 bs + in BS.all (== 0xFF) first16 || hasLongRunOfFF remainder + +prop_unsoundPureGenKey :: forall v. + ( UnsoundPureKESAlgorithm v + , EqST (SignKeyKES v) + ) + => Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_unsoundPureGenKey _ seedPSB = ioProperty $ do + let seed = mkSeedFromBytes $ psbToByteString seedPSB + let skPure = unsoundPureGenKeyKES @v seed + withSK seedPSB $ \sk -> do + bracket + (unsoundPureSignKeyKESToSoundSignKeyKES skPure) + forgetSignKeyKES + (equalsM sk) + +prop_unsoundPureDeriveVerKey :: forall v. + ( UnsoundPureKESAlgorithm v + ) + => Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_unsoundPureDeriveVerKey _ seedPSB = ioProperty $ do + let seed = mkSeedFromBytes $ psbToByteString seedPSB + let skPure = unsoundPureGenKeyKES @v seed + vkPure = unsoundPureDeriveVerKeyKES @v skPure + vk <- withSK seedPSB deriveVerKeyKES + return $ vkPure === vk + +prop_unsoundPureUpdateKES :: forall v. + ( UnsoundPureKESAlgorithm v + , ContextKES v ~ () + , EqST (SignKeyKES v) + ) + => Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_unsoundPureUpdateKES _ seedPSB = ioProperty $ do + let seed = mkSeedFromBytes $ psbToByteString seedPSB + let skPure = unsoundPureGenKeyKES @v seed + skPure'Maybe = unsoundPureUpdateKES () skPure 0 + withSK seedPSB $ \sk -> do + bracket + (updateKES () sk 0) + (maybe (return ()) forgetSignKeyKES) $ \sk'Maybe -> do + case skPure'Maybe of + Nothing -> + case sk'Maybe of + Nothing -> return $ property True + Just _ -> return $ counterexample "pure does not update, but should" $ property False + Just skPure' -> + bracket + (unsoundPureSignKeyKESToSoundSignKeyKES skPure') + forgetSignKeyKES $ \sk'' -> + case sk'Maybe of + Nothing -> + return (counterexample "pure updates, but shouldn't" $ property False) + Just sk' -> + property <$> equalsM sk' sk'' + +prop_unsoundPureSign :: forall v. + ( UnsoundPureKESAlgorithm v + , ContextKES v ~ () + , Signable v Message + ) + => Proxy v + -> PinnedSizedBytes (SeedSizeKES v) + -> Message + -> Property +prop_unsoundPureSign _ seedPSB msg = ioProperty $ do + let seed = mkSeedFromBytes $ psbToByteString seedPSB + let skPure = unsoundPureGenKeyKES @v seed + sigPure = unsoundPureSignKES () 0 msg skPure + sig <- withSK seedPSB $ signKES () 0 msg + return $ sigPure === sig + diff --git a/cardano-crypto-tests/src/Test/Crypto/Util.hs b/cardano-crypto-tests/src/Test/Crypto/Util.hs index c0c7d7441..39c91a0ce 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Util.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Util.hs @@ -59,6 +59,10 @@ module Test.Crypto.Util , noExceptionsThrown , doesNotThrow + -- * Direct ser/deser helpers + , directSerialiseToBS + , directDeserialiseFromBS + -- * Error handling , eitherShowError @@ -95,12 +99,18 @@ import Codec.CBOR.Write ( ) import Cardano.Crypto.Seed (Seed, mkSeedFromBytes) import Cardano.Crypto.Util (SignableRepresentation(..)) +import Cardano.Crypto.DirectSerialise import Crypto.Random ( ChaChaDRG , MonadPseudoRandom , drgNewTest , withDRG ) +import Cardano.Crypto.Libsodium.Memory + ( unpackByteStringCStringLen + , packByteStringCStringLen + , allocaBytes + ) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BS8 @@ -130,7 +140,14 @@ import qualified Test.QuickCheck.Gen as Gen import Control.Monad (guard, when) import GHC.TypeLits (Nat, KnownNat, natVal) import Formatting.Buildable (Buildable (..), build) -import Control.Concurrent.Class.MonadMVar (MVar, withMVar, newMVar) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) +import Control.Concurrent.Class.MonadMVar + ( MVar + , withMVar + , newMVar + , newMVar + ) import GHC.Stack (HasCallStack) -------------------------------------------------------------------------------- @@ -375,3 +392,29 @@ mkLock = Lock <$> newMVar () eitherShowError :: (HasCallStack, Show e) => Either e a -> IO a eitherShowError (Left e) = error (show e) eitherShowError (Right a) = return a + +-------------------------------------------------------------------------------- +-- Helpers for direct ser/deser +-------------------------------------------------------------------------------- + +directSerialiseToBS :: forall m a. + DirectSerialise a + => MonadST m + => MonadThrow m + => Int + -> a + -> m ByteString +directSerialiseToBS dstsize val = do + allocaBytes dstsize $ \dst -> do + directSerialiseBufChecked dst dstsize val + packByteStringCStringLen (dst, fromIntegral dstsize) + +directDeserialiseFromBS :: forall m a. + DirectDeserialise a + => MonadST m + => MonadThrow m + => ByteString + -> m a +directDeserialiseFromBS bs = do + unpackByteStringCStringLen bs $ \(src, srcsize) -> do + directDeserialiseBufChecked src srcsize