Skip to content

Update with latest RSA from go stdlib #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 0 additions & 230 deletions blindsign/blindrsa/pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,8 @@ package blindrsa
// This file implements the RSASSA-PSS signature scheme according to RFC 8017.

import (
"bytes"
"crypto"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
"math/big"
)

// Per RFC 8017, Section 9.1
Expand Down Expand Up @@ -132,227 +126,3 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 13. Output EM.
return em, nil
}

func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// See RFC 8017, Section 9.1.2.

hLen := hash.Size()
if sLen == PSSSaltLengthEqualsHash {
sLen = hLen
}
emLen := (emBits + 7) / 8
if emLen != len(em) {
return errors.New("rsa: internal error: inconsistent length")
}

// 1. If the length of M is greater than the input limitation for the
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
// and stop.
//
// 2. Let mHash = Hash(M), an octet string of length hLen.
if hLen != len(mHash) {
fmt.Println("here3", hLen, len(mHash))
return ErrVerification
}

// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
if emLen < hLen+sLen+2 {
fmt.Println("here2")
return ErrVerification
}

// 4. If the rightmost octet of EM does not have hexadecimal value
// 0xbc, output "inconsistent" and stop.
if em[emLen-1] != 0xbc {
fmt.Println("here")
return ErrVerification
}

// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
// let H be the next hLen octets.
db := em[:emLen-hLen-1]
h := em[emLen-hLen-1 : emLen-1]

// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
var bitMask byte = 0xff >> (8*emLen - emBits)
if em[0] & ^bitMask != 0 {
fmt.Println("here4")
return ErrVerification
}

// 7. Let dbMask = MGF(H, emLen - hLen - 1).
//
// 8. Let DB = maskedDB \xor dbMask.
mgf1XOR(db, hash, h)

// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
// to zero.
db[0] &= bitMask

// If we don't know the salt length, look for the 0x01 delimiter.
if sLen == PSSSaltLengthAuto {
psLen := bytes.IndexByte(db, 0x01)
if psLen < 0 {
fmt.Println("here5")
return ErrVerification
}
sLen = len(db) - psLen - 1
}

// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
// or if the octet at position emLen - hLen - sLen - 1 (the leftmost
// position is "position 1") does not have hexadecimal value 0x01,
// output "inconsistent" and stop.
psLen := emLen - hLen - sLen - 2
for _, e := range db[:psLen] {
if e != 0x00 {
fmt.Println("here6")
return ErrVerification
}
}
if db[psLen] != 0x01 {
fmt.Println("here7")
return ErrVerification
}

// 11. Let salt be the last sLen octets of DB.
salt := db[len(db)-sLen:]

// 12. Let
// M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
// M' is an octet string of length 8 + hLen + sLen with eight
// initial zero octets.
//
// 13. Let H' = Hash(M'), an octet string of length hLen.
var prefix [8]byte
hash.Write(prefix[:])
hash.Write(mHash)
hash.Write(salt)

h0 := hash.Sum(nil)

// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
if !bytes.Equal(h0, h) { // TODO: constant time?
fmt.Println("here8")
return ErrVerification
}
return nil
}

// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
func signPSSWithSalt(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return nil, err
}
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
return nil, err
}
s := make([]byte, priv.Size())
copyWithLeftPad(s, c.Bytes())
return s, nil
}

const (
// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
// as possible when signing, and to be auto-detected when verifying.
PSSSaltLengthAuto = 0
// PSSSaltLengthEqualsHash causes the salt length to equal the length
// of the hash used in the signature.
PSSSaltLengthEqualsHash = -1
)

// PSSOptions contains options for creating and verifying PSS signatures.
type PSSOptions struct {
// SaltLength controls the length of the salt used in the PSS
// signature. It can either be a number of bytes, or one of the special
// PSSSaltLength constants.
SaltLength int

// Hash is the hash function used to generate the message digest. If not
// zero, it overrides the hash function passed to SignPSS. It's required
// when using PrivateKey.Sign.
Hash crypto.Hash
}

// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
func (opts *PSSOptions) HashFunc() crypto.Hash {
return opts.Hash
}

func (opts *PSSOptions) saltLength() int {
if opts == nil {
return PSSSaltLengthAuto
}
return opts.SaltLength
}

// SignPSS calculates the signature of digest using PSS.
//
// digest must be the result of hashing the input message using the given hash
// function. The opts argument may be nil, in which case sensible defaults are
// used. If opts.Hash is set, it overrides hash.
func SignPSS(rand io.Reader, priv *rsa.PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
if opts != nil && opts.Hash != 0 {
hash = opts.Hash
}

saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
saltLength = priv.Size() - 2 - hash.Size()
case PSSSaltLengthEqualsHash:
saltLength = hash.Size()
}

salt := make([]byte, saltLength)
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
}
return signPSSWithSalt(rand, priv, hash, digest, salt)
}

// VerifyPSS verifies a PSS signature.
//
// A valid signature is indicated by returning a nil error. digest must be the
// result of hashing the input message using the given hash function. The opts
// argument may be nil, in which case sensible defaults are used. opts.Hash is
// ignored.
func VerifyPSS(pub *rsa.PublicKey, hash hash.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
if len(sig) != pub.Size() {
fmt.Println("1")
return ErrVerification
}
s := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, s)
emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
emBytes := m.Bytes()
if m.BitLen() > emLen*8 {
fmt.Println("2")
return ErrVerification
}

em := make([]byte, emLen)
copyWithLeftPad(em, emBytes)

return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash)
}

// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
// needed.
func copyWithLeftPad(dest, src []byte) {
numPaddingBytes := len(dest) - len(src)
for i := 0; i < numPaddingBytes; i++ {
dest[i] = 0
}
copy(dest[numPaddingBytes:], src)
}
17 changes: 9 additions & 8 deletions blindsign/blindrsa/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ var (
bigOne = big.NewInt(1)
)

// ErrVerification represents a failure to verify a signature.
// It is deliberately vague to avoid adaptive attacks.
var ErrVerification = errors.New("crypto/rsa: verification error")

// incCounter increments a four byte, big-endian counter.
func incCounter(c *[4]byte) {
if c[3]++; c[3] != 0 {
Expand Down Expand Up @@ -87,12 +83,17 @@ func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int {
return c
}

// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
// random source is given, RSA blinding is used.
func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
// TODO(agl): can we get away with reusing blinds?
if c.Cmp(priv.N) > 0 {
err = rsa.ErrDecryption
return
}
if priv.N.Sign() == 0 {
return nil, rsa.ErrDecryption
}

var ir *big.Int
if random != nil {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem too important.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it shuffles the random source in case a user-provided source is provided.
in this package, the random source is always crypto/rand, which it's assumed to be secure.

Expand All @@ -102,7 +103,7 @@ func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, er
// by multiplying by the multiplicative inverse of r.

var r *big.Int

ir = new(big.Int)
for {
r, err = rand.Int(random, priv.N)
if err != nil {
Expand All @@ -111,13 +112,13 @@ func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, er
if r.Cmp(bigZero) == 0 {
r = bigOne
}
ir = new(big.Int).ModInverse(r, priv.N)
if ir != nil {
ok := ir.ModInverse(r, priv.N)
if ok != nil {
break
}
}
bigE := big.NewInt(int64(priv.E))
rpowe := new(big.Int).Exp(r, bigE, priv.N)
rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0
cCopy := new(big.Int).Set(c)
cCopy.Mul(cCopy, rpowe)
cCopy.Mod(cCopy, priv.N)
Expand Down