| package hrss |
| |
| import ( |
| "crypto/hmac" |
| "crypto/sha256" |
| "crypto/subtle" |
| "encoding/binary" |
| "io" |
| "math/bits" |
| ) |
| |
// Parameters of the NTRU-HRSS-701 instance implemented in this package.
const (
	// N is the number of coefficients in each polynomial.
	N = 701
	// Q is the modulus for "mod Q" polynomials (2^13).
	Q = 8192

	// mod3Bytes is the size of a packed mod-3 polynomial: the first N-1
	// coefficients, five base-3 digits per byte.
	mod3Bytes = (N - 1) / 5
	// modQBytes is the size of a packed mod-Q polynomial: the first N-1
	// coefficients at 13 bits each, rounded up to whole bytes.
	modQBytes = ((N-1)*13 + 7) / 8
)

// Serialized sizes, in bytes, of public keys and ciphertexts.
const (
	PublicKeySize  = modQBytes
	CiphertextSize = modQBytes
)

// Layout of a bitsliced polynomial in machine words.
const (
	bitsPerWord      = bits.UintSize
	wordsPerPoly     = (N + bitsPerWord - 1) / bitsPerWord
	fullWordsPerPoly = N / bitsPerWord
	bitsInLastWord   = N % bitsPerWord
)
| |
| // poly3 represents a degree-N polynomial over GF(3). Each coefficient is |
| // bitsliced across the |s| and |a| arrays, like this: |
| // |
| // s | a | value |
| // ----------------- |
| // 0 | 0 | 0 |
| // 0 | 1 | 1 |
| // 1 | 0 | 2 (aka -1) |
| // 1 | 1 | <invalid> |
| // |
| // ('s' is for sign, and 'a' is just a letter.) |
| // |
| // Once bitsliced as such, the following circuits can be used to implement |
| // addition and multiplication mod 3: |
| // |
| // (s3, a3) = (s1, a1) × (s2, a2) |
| // s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2) |
| // a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2) |
| // |
| // (s3, a3) = (s1, a1) + (s2, a2) |
| // t1 = ~(s1 ∨ a1) |
| // t2 = ~(s2 ∨ a2) |
| // s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1) |
| // a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1) |
| // |
| // Negating a value just involves swapping s and a. |
| type poly3 struct { |
| s [wordsPerPoly]uint |
| a [wordsPerPoly]uint |
| } |
| |
| func (p *poly3) trim() { |
| p.s[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 |
| p.a[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 |
| } |
| |
| func (p *poly3) zero() { |
| for i := range p.a { |
| p.s[i] = 0 |
| p.a[i] = 0 |
| } |
| } |
| |
| func (p *poly3) fromDiscrete(in *poly) { |
| var shift uint |
| s := p.s[:] |
| a := p.a[:] |
| s[0] = 0 |
| a[0] = 0 |
| |
| for _, v := range in { |
| s[0] >>= 1 |
| s[0] |= uint((v>>1)&1) << (bitsPerWord - 1) |
| a[0] >>= 1 |
| a[0] |= uint(v&1) << (bitsPerWord - 1) |
| shift++ |
| if shift == bitsPerWord { |
| s = s[1:] |
| a = a[1:] |
| s[0] = 0 |
| a[0] = 0 |
| shift = 0 |
| } |
| } |
| |
| a[0] >>= bitsPerWord - shift |
| s[0] >>= bitsPerWord - shift |
| } |
| |
| func (p *poly3) fromModQ(in *poly) int { |
| var shift uint |
| s := p.s[:] |
| a := p.a[:] |
| s[0] = 0 |
| a[0] = 0 |
| ok := 1 |
| |
| for _, v := range in { |
| vMod3, vOk := modQToMod3(v) |
| ok &= vOk |
| |
| s[0] >>= 1 |
| s[0] |= uint((vMod3>>1)&1) << (bitsPerWord - 1) |
| a[0] >>= 1 |
| a[0] |= uint(vMod3&1) << (bitsPerWord - 1) |
| shift++ |
| if shift == bitsPerWord { |
| s = s[1:] |
| a = a[1:] |
| s[0] = 0 |
| a[0] = 0 |
| shift = 0 |
| } |
| } |
| |
| a[0] >>= bitsPerWord - shift |
| s[0] >>= bitsPerWord - shift |
| |
| return ok |
| } |
| |
| func (p *poly3) fromDiscreteMod3(in *poly) { |
| var shift uint |
| s := p.s[:] |
| a := p.a[:] |
| s[0] = 0 |
| a[0] = 0 |
| |
| for _, v := range in { |
| // This duplicates the 13th bit upwards to the top of the |
| // uint16, essentially treating it as a sign bit and converting |
| // into a signed int16. The signed value is reduced mod 3, |
| // yeilding {-2, -1, 0, 1, 2}. |
| v = uint16((int16(v<<3)>>3)%3) & 7 |
| |
| // We want to map v thus: |
| // {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}. We take the bottom |
| // three bits and then the constants below, when shifted by |
| // those three bits, perform the required mapping. |
| s[0] >>= 1 |
| s[0] |= (0xbc >> v) << (bitsPerWord - 1) |
| a[0] >>= 1 |
| a[0] |= (0x7a >> v) << (bitsPerWord - 1) |
| shift++ |
| if shift == bitsPerWord { |
| s = s[1:] |
| a = a[1:] |
| s[0] = 0 |
| a[0] = 0 |
| shift = 0 |
| } |
| } |
| |
| a[0] >>= bitsPerWord - shift |
| s[0] >>= bitsPerWord - shift |
| } |
| |
| func (p *poly3) marshal(out []byte) { |
| s := p.s[:] |
| a := p.a[:] |
| sw := s[0] |
| aw := a[0] |
| var shift int |
| |
| for i := 0; i < 700; i += 5 { |
| acc, scale := 0, 1 |
| for j := 0; j < 5; j++ { |
| v := int(aw&1) | int(sw&1)<<1 |
| acc += scale * v |
| scale *= 3 |
| |
| shift++ |
| if shift == bitsPerWord { |
| s = s[1:] |
| a = a[1:] |
| sw = s[0] |
| aw = a[0] |
| shift = 0 |
| } else { |
| sw >>= 1 |
| aw >>= 1 |
| } |
| } |
| |
| out[0] = byte(acc) |
| out = out[1:] |
| } |
| } |
| |
| func (p *poly) fromMod2(in *poly2) { |
| var shift uint |
| words := in[:] |
| word := words[0] |
| |
| for i := range p { |
| p[i] = uint16(word & 1) |
| word >>= 1 |
| shift++ |
| if shift == bitsPerWord { |
| words = words[1:] |
| word = words[0] |
| shift = 0 |
| } |
| } |
| } |
| |
| func (p *poly) fromMod3(in *poly3) { |
| var shift uint |
| s := in.s[:] |
| a := in.a[:] |
| sw := s[0] |
| aw := a[0] |
| |
| for i := range p { |
| p[i] = uint16(aw&1 | (sw&1)<<1) |
| aw >>= 1 |
| sw >>= 1 |
| shift++ |
| if shift == bitsPerWord { |
| a = a[1:] |
| s = s[1:] |
| aw = a[0] |
| sw = s[0] |
| shift = 0 |
| } |
| } |
| } |
| |
| func (p *poly) fromMod3ToModQ(in *poly3) { |
| var shift uint |
| s := in.s[:] |
| a := in.a[:] |
| sw := s[0] |
| aw := a[0] |
| |
| for i := range p { |
| p[i] = mod3ToModQ(uint16(aw&1 | (sw&1)<<1)) |
| aw >>= 1 |
| sw >>= 1 |
| shift++ |
| if shift == bitsPerWord { |
| a = a[1:] |
| s = s[1:] |
| aw = a[0] |
| sw = s[0] |
| shift = 0 |
| } |
| } |
| } |
| |
// lsbToAll returns all ones if the least-significant bit of v is set and zero
// otherwise, without branching.
func lsbToAll(v uint) uint {
	return -(v & 1)
}
| |
| func (p *poly3) mulConst(ms, ma uint) { |
| ms = lsbToAll(ms) |
| ma = lsbToAll(ma) |
| |
| for i := range p.a { |
| p.s[i], p.a[i] = (ma&p.s[i])^(ms&p.a[i]), (ma&p.a[i])^(ms&p.s[i]) |
| } |
| } |
| |
| func cmovWords(out, in *[wordsPerPoly]uint, mov uint) { |
| for i := range out { |
| out[i] = (out[i] & ^mov) | (in[i] & mov) |
| } |
| } |
| |
| func rotWords(out, in *[wordsPerPoly]uint, bits uint) { |
| start := bits / bitsPerWord |
| n := (N - bits) / bitsPerWord |
| |
| for i := uint(0); i < n; i++ { |
| out[i] = in[start+i] |
| } |
| |
| carry := in[wordsPerPoly-1] |
| |
| for i := uint(0); i < start; i++ { |
| out[n+i] = carry | in[i]<<bitsInLastWord |
| carry = in[i] >> (bitsPerWord - bitsInLastWord) |
| } |
| |
| out[wordsPerPoly-1] = carry |
| } |
| |
| // rotBits right-rotates the bits in |in|. bits must be a non-zero power of two |
| // and less than bitsPerWord. |
| func rotBits(out, in *[wordsPerPoly]uint, bits uint) { |
| if (bits == 0 || (bits & (bits - 1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2) { |
| panic("internal error"); |
| } |
| |
| carry := in[wordsPerPoly-1] << (bitsPerWord - bits) |
| |
| for i := wordsPerPoly - 2; i >= 0; i-- { |
| out[i] = carry | in[i]>>bits |
| carry = in[i] << (bitsPerWord - bits) |
| } |
| |
| out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits |
| } |
| |
| func (p *poly3) rotWords(bits uint, in *poly3) { |
| rotWords(&p.s, &in.s, bits) |
| rotWords(&p.a, &in.a, bits) |
| } |
| |
| func (p *poly3) rotBits(bits uint, in *poly3) { |
| rotBits(&p.s, &in.s, bits) |
| rotBits(&p.a, &in.a, bits) |
| } |
| |
| func (p *poly3) cmov(in *poly3, mov uint) { |
| cmovWords(&p.s, &in.s, mov) |
| cmovWords(&p.a, &in.a, mov) |
| } |
| |
| func (p *poly3) rot(bits uint) { |
| if bits > N { |
| panic("invalid") |
| } |
| var shifted poly3 |
| |
| shift := uint(9) |
| for ; (1 << shift) >= bitsPerWord; shift-- { |
| shifted.rotWords(1<<shift, p) |
| p.cmov(&shifted, lsbToAll(bits>>shift)) |
| } |
| for ; shift < 9; shift-- { |
| shifted.rotBits(1<<shift, p) |
| p.cmov(&shifted, lsbToAll(bits>>shift)) |
| } |
| } |
| |
| func (p *poly3) fmadd(ms, ma uint, in *poly3) { |
| ms = lsbToAll(ms) |
| ma = lsbToAll(ma) |
| |
| for i := range p.a { |
| products := (ma & in.s[i]) ^ (ms & in.a[i]) |
| producta := (ma & in.a[i]) ^ (ms & in.s[i]) |
| |
| ns1Ana1 := ^p.s[i] & ^p.a[i] |
| ns2Ana2 := ^products & ^producta |
| |
| p.s[i], p.a[i] = (p.a[i]&producta)^(ns1Ana1&products)^(p.s[i]&ns2Ana2), (p.s[i]&products)^(ns1Ana1&producta)^(p.a[i]&ns2Ana2) |
| } |
| } |
| |
| func (p *poly3) modPhiN() { |
| factora := uint(int(p.s[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) |
| factors := uint(int(p.a[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) |
| ns2Ana2 := ^factors & ^factora |
| |
| for i := range p.s { |
| ns1Ana1 := ^p.s[i] & ^p.a[i] |
| p.s[i], p.a[i] = (p.a[i]&factora)^(ns1Ana1&factors)^(p.s[i]&ns2Ana2), (p.s[i]&factors)^(ns1Ana1&factora)^(p.a[i]&ns2Ana2) |
| } |
| } |
| |
| func (p *poly3) cswap(other *poly3, swap uint) { |
| for i := range p.s { |
| sums := swap & (p.s[i] ^ other.s[i]) |
| p.s[i] ^= sums |
| other.s[i] ^= sums |
| |
| suma := swap & (p.a[i] ^ other.a[i]) |
| p.a[i] ^= suma |
| other.a[i] ^= suma |
| } |
| } |
| |
| func (p *poly3) mulx() { |
| carrys := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 |
| carrya := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 |
| |
| for i := range p.s { |
| outCarrys := p.s[i] >> (bitsPerWord - 1) |
| outCarrya := p.a[i] >> (bitsPerWord - 1) |
| p.s[i] <<= 1 |
| p.a[i] <<= 1 |
| p.s[i] |= carrys |
| p.a[i] |= carrya |
| carrys = outCarrys |
| carrya = outCarrya |
| } |
| } |
| |
| func (p *poly3) divx() { |
| var carrys, carrya uint |
| |
| for i := len(p.s) - 1; i >= 0; i-- { |
| outCarrys := p.s[i] & 1 |
| outCarrya := p.a[i] & 1 |
| p.s[i] >>= 1 |
| p.a[i] >>= 1 |
| p.s[i] |= carrys << (bitsPerWord - 1) |
| p.a[i] |= carrya << (bitsPerWord - 1) |
| carrys = outCarrys |
| carrya = outCarrya |
| } |
| } |
| |
| type poly2 [wordsPerPoly]uint |
| |
| func (p *poly2) fromDiscrete(in *poly) { |
| var shift uint |
| words := p[:] |
| words[0] = 0 |
| |
| for _, v := range in { |
| words[0] >>= 1 |
| words[0] |= uint(v&1) << (bitsPerWord - 1) |
| shift++ |
| if shift == bitsPerWord { |
| words = words[1:] |
| words[0] = 0 |
| shift = 0 |
| } |
| } |
| |
| words[0] >>= bitsPerWord - shift |
| } |
| |
| func (p *poly2) setPhiN() { |
| for i := range p { |
| p[i] = ^uint(0) |
| } |
| p[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 |
| } |
| |
| func (p *poly2) cswap(other *poly2, swap uint) { |
| for i := range p { |
| sum := swap & (p[i] ^ other[i]) |
| p[i] ^= sum |
| other[i] ^= sum |
| } |
| } |
| |
| func (p *poly2) fmadd(m uint, in *poly2) { |
| m = ^(m - 1) |
| |
| for i := range p { |
| p[i] ^= in[i] & m |
| } |
| } |
| |
| func (p *poly2) lshift1() { |
| var carry uint |
| for i := range p { |
| nextCarry := p[i] >> (bitsPerWord - 1) |
| p[i] <<= 1 |
| p[i] |= carry |
| carry = nextCarry |
| } |
| } |
| |
| func (p *poly2) rshift1() { |
| var carry uint |
| for i := len(p) - 1; i >= 0; i-- { |
| nextCarry := p[i] & 1 |
| p[i] >>= 1 |
| p[i] |= carry << (bitsPerWord - 1) |
| carry = nextCarry |
| } |
| } |
| |
| func (p *poly2) rot(bits uint) { |
| if bits > N { |
| panic("invalid") |
| } |
| var shifted [wordsPerPoly]uint |
| out := (*[wordsPerPoly]uint)(p) |
| |
| shift := uint(9) |
| for ; (1 << shift) >= bitsPerWord; shift-- { |
| rotWords(&shifted, out, 1<<shift) |
| cmovWords(out, &shifted, lsbToAll(bits>>shift)) |
| } |
| for ; shift < 9; shift-- { |
| rotBits(&shifted, out, 1<<shift) |
| cmovWords(out, &shifted, lsbToAll(bits>>shift)) |
| } |
| } |
| |
| type poly [N]uint16 |
| |
// marshal packs the first N-1 coefficients of in, 13 bits each in a
// little-endian bit stream, into the first modQBytes bytes of out. The final
// coefficient (index N-1) is not stored; unmarshal reconstructs it from the
// others. Coefficients are assumed to be reduced mod Q (below 8192).
func (in *poly) marshal(out []byte) {
	p := in[:]

	// Each pass packs eight 13-bit coefficients into 13 bytes.
	for len(p) >= 8 {
		out[0] = byte(p[0])
		out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
		out[2] = byte(p[1] >> 3)
		out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
		out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
		out[5] = byte(p[3] >> 1)
		out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4)
		out[7] = byte(p[4] >> 4)
		out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1)
		out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6)
		out[10] = byte(p[6] >> 2)
		out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3)
		out[12] = byte(p[7] >> 5)

		p = p[8:]
		out = out[13:]
	}

	// Five coefficients remain. The first four fill the last seven bytes (the
	// top four bits of the final byte stay zero); the fifth, coefficient N-1,
	// is deliberately omitted.
	out[0] = byte(p[0])
	out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
	out[2] = byte(p[1] >> 3)
	out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
	out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
	out[5] = byte(p[3] >> 1)
	out[6] = byte(p[3] >> 9)
}
| |
| func (out *poly) unmarshal(in []byte) bool { |
| p := out[:] |
| for i := 0; i < 87; i++ { |
| p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 |
| p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 |
| p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 |
| p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 |
| p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12 |
| p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7 |
| p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10 |
| p[7] = uint16(in[11]>>3) | uint16(in[12])<<5 |
| |
| p = p[8:] |
| in = in[13:] |
| } |
| |
| // There are four coefficients left over |
| p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 |
| p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 |
| p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 |
| p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 |
| |
| if in[6]&0xf0 != 0 { |
| return false |
| } |
| |
| out[N-1] = 0 |
| var top int |
| for _, v := range out { |
| top += int(v) |
| } |
| |
| out[N-1] = uint16(-top) % Q |
| return true |
| } |
| |
| func (in *poly) marshalS3(out []byte) { |
| p := in[:] |
| for len(p) >= 5 { |
| out[0] = byte(p[0] + p[1]*3 + p[2]*9 + p[3]*27 + p[4]*81) |
| out = out[1:] |
| p = p[5:] |
| } |
| } |
| |
| func (out *poly) unmarshalS3(in []byte) bool { |
| p := out[:] |
| for i := 0; i < 140; i++ { |
| c := in[0] |
| if c >= 243 { |
| return false |
| } |
| p[0] = uint16(c % 3) |
| p[1] = uint16((c / 3) % 3) |
| p[2] = uint16((c / 9) % 3) |
| p[3] = uint16((c / 27) % 3) |
| p[4] = uint16((c / 81) % 3) |
| |
| p = p[5:] |
| in = in[1:] |
| } |
| |
| out[N-1] = 0 |
| return true |
| } |
| |
| func (p *poly) modPhiN() { |
| for i := range p { |
| p[i] = (p[i] + Q - p[N-1]) % Q |
| } |
| } |
| |
| func (out *poly) shortSample(in []byte) { |
| // b a result |
| // 00 00 00 |
| // 00 01 01 |
| // 00 10 10 |
| // 00 11 11 |
| // 01 00 10 |
| // 01 01 00 |
| // 01 10 01 |
| // 01 11 11 |
| // 10 00 01 |
| // 10 01 10 |
| // 10 10 00 |
| // 10 11 11 |
| // 11 00 11 |
| // 11 01 11 |
| // 11 10 11 |
| // 11 11 11 |
| |
| // 1111 1111 1100 1001 1101 0010 1110 0100 |
| // f f c 9 d 2 e 4 |
| const lookup = uint32(0xffc9d2e4) |
| |
| p := out[:] |
| for i := 0; i < 87; i++ { |
| v := binary.LittleEndian.Uint32(in) |
| v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) |
| for j := 0; j < 8; j++ { |
| p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) |
| v2 >>= 4 |
| } |
| p = p[8:] |
| in = in[4:] |
| } |
| |
| // There are four values remaining. |
| v := binary.LittleEndian.Uint32(in) |
| v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) |
| for j := 0; j < 4; j++ { |
| p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) |
| v2 >>= 4 |
| } |
| |
| out[N-1] = 0 |
| } |
| |
| func (out *poly) shortSamplePlus(in []byte) { |
| out.shortSample(in) |
| |
| var sum uint16 |
| for i := 0; i < N-1; i++ { |
| sum += mod3ResultToModQ(out[i] * out[i+1]) |
| } |
| |
| scale := 1 + (1 & (sum >> 12)) |
| for i := 0; i < len(out); i += 2 { |
| out[i] = (out[i] * scale) % 3 |
| } |
| } |
| |
// mul writes the product of the polynomials a and b (coefficient i at index
// i) into out[:2*len(a)]. Arithmetic is in uint16 and therefore wraps mod
// 2^16. It uses Karatsuba recursion, falling back to schoolbook
// multiplication once len(a) < schoolbookLimit.
//
// Assumptions, inferred from the indexing below:
//   - len(b) == len(a); b is split at the same point as a.
//   - out holds at least 2*len(a) entries.
//   - scratch is large enough for the recursive temporaries. poly.mul passes
//     2*N entries; the exact minimum is not established here.
//
// out is written before a and b are fully read, so out and scratch must not
// overlap a or b.
func mul(out, scratch, a, b []uint16) {
	const schoolbookLimit = 32
	if len(a) < schoolbookLimit {
		for i := 0; i < len(a)*2; i++ {
			out[i] = 0
		}
		for i := range a {
			for j := range b {
				out[i+j] += a[i] * b[j]
			}
		}
		return
	}

	// Split each operand into low and high halves; highLen is one larger than
	// lowLen when len(a) is odd.
	lowLen := len(a) / 2
	highLen := len(a) - lowLen
	aLow, aHigh := a[:lowLen], a[lowLen:]
	bLow, bHigh := b[:lowLen], b[lowLen:]

	// out[:highLen] = aLow + aHigh, zero-extending aLow when lengths differ.
	for i := 0; i < lowLen; i++ {
		out[i] = aHigh[i] + aLow[i]
	}
	if highLen != lowLen {
		out[lowLen] = aHigh[lowLen]
	}

	// out[highLen:2*highLen] = bLow + bHigh, likewise.
	for i := 0; i < lowLen; i++ {
		out[highLen+i] = bHigh[i] + bLow[i]
	}
	if highLen != lowLen {
		out[highLen+lowLen] = bHigh[lowLen]
	}

	// scratch = (aLow+aHigh)(bLow+bHigh); this consumes the sums in out before
	// the next two calls overwrite out with aHigh·bHigh and aLow·bLow.
	mul(scratch, scratch[2*highLen:], out[:highLen], out[highLen:highLen*2])
	mul(out[lowLen*2:], scratch[2*highLen:], aHigh, bHigh)
	mul(out, scratch[2*highLen:], aLow, bLow)

	// Middle term: scratch -= aLow·bLow + aHigh·bHigh. aHigh·bHigh has one
	// extra coefficient when the halves differ in length.
	for i := 0; i < lowLen*2; i++ {
		scratch[i] -= out[i] + out[lowLen*2+i]
	}
	if lowLen != highLen {
		scratch[lowLen*2] -= out[lowLen*4]
	}

	// Add the middle term in at offset lowLen.
	for i := 0; i < 2*highLen; i++ {
		out[lowLen+i] += scratch[i]
	}
}
| |
| func (out *poly) mul(a, b *poly) { |
| var prod, scratch [2 * N]uint16 |
| mul(prod[:], scratch[:], a[:], b[:]) |
| for i := range out { |
| out[i] = (prod[i] + prod[i+N]) % Q |
| } |
| } |
| |
| func (p3 *poly3) mulMod3(x, y *poly3) { |
| // (𝑥^n - 1) is a multiple of Φ(N) so we can work mod (𝑥^n - 1) here and |
| // (reduce mod Φ(N) afterwards. |
| x3 := *x |
| y3 := *y |
| s := x3.s[:] |
| a := x3.a[:] |
| sw := s[0] |
| aw := a[0] |
| p3.zero() |
| var shift uint |
| for i := 0; i < N; i++ { |
| p3.fmadd(sw, aw, &y3) |
| sw >>= 1 |
| aw >>= 1 |
| shift++ |
| if shift == bitsPerWord { |
| s = s[1:] |
| a = a[1:] |
| sw = s[0] |
| aw = a[0] |
| shift = 0 |
| } |
| y3.mulx() |
| } |
| p3.modPhiN() |
| } |
| |
// mod3ToModQ maps n in {0, 1, 2, 3} to {0, 1, Q-1, 0xffff} by selecting a
// 16-bit lane of a constant table. The value 3 is never a valid GF(3) element;
// its out-of-range result lets modQToMod3 detect invalid inputs.
func mod3ToModQ(n uint16) uint16 {
	const table uint64 = 0xffff1fff00010000
	return uint16(table >> (16 * n))
}
| |
| // modQToMod3 maps {0, 1, Q-1} to {(0, 0), (0, 1), (1, 0)} and also returns an int |
| // which is one if the input is in range and zero otherwise. |
| func modQToMod3(n uint16) (uint16, int) { |
| result := (n&3 - (n>>1)&1) |
| return result, subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n)) |
| } |
| |
// mod3ResultToModQ maps n in {0, 1, 2, 4}, the possible products of two values
// in {0, 1, 2}, to their mod-Q representatives {0, 1, Q-1, 1} without
// branching.
func mod3ResultToModQ(n uint16) uint16 {
	// Bit n of isOneSet is set for n ∈ {1, 4}; bit n of nonNegSet is set for
	// n ∈ {0, 1, 4}. Any other n maps to Q-1.
	const isOneSet, nonNegSet uint16 = 0x12, 0x13
	isOne := isOneSet >> n & 1
	negative := (nonNegSet>>n&1 - 1) & 0x1fff
	return negative | isOne
}
| |
| // mulXMinus1 sets out to a×(𝑥 - 1) mod (𝑥^n - 1) |
| func (out *poly) mulXMinus1() { |
| // Multiplying by (𝑥 - 1) means negating each coefficient and adding in |
| // the value of the previous one. |
| origOut700 := out[700] |
| |
| for i := N - 1; i > 0; i-- { |
| out[i] = (Q - out[i] + out[i-1]) % Q |
| } |
| out[0] = (Q - out[0] + origOut700) % Q |
| } |
| |
// lift sets out to the mod-Q lift of a as used by owf and Decap. It computes
// a/(x-1) mod (3, Φ(N)), maps the resulting coefficients {0, 1, 2} to
// {0, 1, Q-1}, and finally multiplies by (x - 1) mod (Q, x^N - 1).
// a's coefficients are expected to be in {0, 1, 2}.
func (out *poly) lift(a *poly) {
	// We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the
	// Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime).

	// 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up:
	//
	// R.<x> = PolynomialRing(GF(3)…)
	// inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
	// list(inv)[:15]
	//   [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
	//
	// This three-element pattern of coefficients repeats for the whole
	// polynomial.
	//
	// Next define the overbar operator such that z̅ = z[0] +
	// reverse(z[1:]). (Index zero of a polynomial here is the coefficient
	// of the constant term. So index one is the coefficient of 𝑥 and so
	// on.)
	//
	// A less odd way to define this is to see that z̅ negates the indexes,
	// so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
	//
	// The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v,
	// z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum
	// of the point-wise products.) Although we calculated the inverse mod
	// Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end.
	// (That's because (𝑥^N - 1) is a multiple of Φ(N).)
	//
	// When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation
	// of the list of coefficients.
	//
	// Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like:
	//
	// def reverse(xs):
	//   suffix = list(xs[1:])
	//   suffix.reverse()
	//   return [xs[0]] + suffix
	//
	// def rotate(xs):
	//   return [xs[-1]] + xs[:-1]
	//
	// zoverbar = reverse(list(inv) + [0])
	// xzoverbar = rotate(reverse(list(inv) + [0]))
	// x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
	//
	// zoverbar[:15]
	//   [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
	// xzoverbar[:15]
	//   [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
	// x2zoverbar[:15]
	//   [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
	//
	// (For a formula for z̅, see lemma two of appendix B.)
	//
	// After the first three elements have been taken care of, all then have
	// a repeating three-element cycle. The next value (𝑥^3z̅) involves
	// three rotations of the first pattern, thus the three-element cycle
	// lines up. However, the discontinuity in the first three elements
	// obviously moves to a different position. Consider the difference
	// between 𝑥^3z̅ and z̅:
	//
	// [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
	//    [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
	//
	// This pattern of differences is the same for all elements, although it
	// obviously moves right with the rotations.
	//
	// From this, we reach algorithm eight of appendix B.

	// Handle the first three elements of the inner products.
	out[0] = a[0] + a[2]
	out[1] = a[1]
	out[2] = 2*a[0] + a[2]

	// Use the repeating pattern to complete the first three inner products.
	// The sums stay far below 2^16, so no reduction is needed until the end.
	for i := 3; i < 699; i += 3 {
		out[0] += 2*a[i] + a[i+2]
		out[1] += a[i] + 2*a[i+1]
		out[2] += a[i+1] + 2*a[i+2]
	}

	// Handle the fact that the three-element pattern doesn't fill the
	// polynomial exactly (since 701 isn't a multiple of three).
	out[2] += a[700]
	out[0] += 2 * a[699]
	out[1] += a[699] + 2*a[700]

	out[0] = out[0] % 3
	out[1] = out[1] % 3
	out[2] = out[2] % 3

	// Calculate the remaining inner products by taking advantage of the
	// fact that the pattern repeats every three cycles and that the pattern
	// of differences moves with the rotation.
	for i := 3; i < N; i++ {
		// Add twice something is the same as subtracting when working
		// mod 3. Doing it this way avoids underflow. Underflow is bad
		// because "% 3" doesn't work correctly for negative numbers
		// here since underflow will wrap to 2^16-1 and 2^16 isn't a
		// multiple of three.
		out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3
	}

	// Reduce mod Φ(N) by subtracting a multiple of out[700] from every
	// element and convert to mod Q. (See above about adding twice as
	// subtraction.)
	v := out[700] * 2
	for i := range out {
		out[i] = mod3ToModQ((out[i] + v) % 3)
	}

	// Multiply the mod-Q result by (x - 1).
	out.mulXMinus1()
}
| |
| func (a *poly) cswap(b *poly, swap uint16) { |
| for i := range a { |
| sum := swap & (a[i] ^ b[i]) |
| a[i] ^= sum |
| b[i] ^= sum |
| } |
| } |
| |
// lt returns all ones if a < b (as unsigned integers) and zero otherwise.
//
// The inversion routines call it on degree counters that depend on the secret
// key, so it must not branch. bits.Sub is documented to run in time
// independent of its inputs, and its borrow output is exactly 1 when a < b.
func lt(a, b uint) uint {
	_, borrow := bits.Sub(a, b, 0)
	return -borrow
}
| |
// bsMul multiplies two bitsliced GF(3) words (s1, a1) and (s2, a2), lane by
// lane, returning the product planes.
func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) {
	return s1&a2 ^ a1&s2, s1&s2 ^ a1&a2
}
| |
// invertMod3 sets out to the inverse of in modulo (3, Φ(N)), using a
// fixed-iteration, branch-free variant of the "almost inverse" algorithm. The
// loop always runs 2*(N-1)-1 times; stillGoing masks off updates once the
// degree of f reaches zero. The result is b·x^(-k) scaled by the final
// constant term of f (which is its own inverse in GF(3)), then reduced mod
// Φ(N).
func (out *poly3) invertMod3(in *poly3) {
	// This algorithm follows algorithm 10 in the paper. (Although note that
	// the paper appears to have a bug: k should start at zero, not one.)
	// The best explanation for why it works is in the "Why it works"
	// section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)

	var b, c, g poly3
	f := *in

	// g = Φ(N): every a bit set. This also sets g's padding bits. They only
	// ever reach f through swaps and fmadd, and f's padding is masked off
	// after each divx below.
	for i := range g.a {
		g.a[i] = ^uint(0)
	}

	b.a[0] = 1

	var f0s, f0a uint
	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// Multiplier is -(f0·g0). Swapping the s and a results negates, and
		// g0 is its own inverse in GF(3).
		ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0])
		ss, sa = sa&stillGoing&1, ss&stillGoing&1
		// Swap when the multiplier is non-zero and deg f < deg g.
		shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		// Eliminate f's constant term: f -= f0·g0·g, b -= f0·g0·c.
		f.fmadd(ss, sa, &g)
		b.fmadd(ss, sa, &c)

		// f /= x, clearing coefficient N-1 and the padding above it.
		f.divx()
		f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		// c *= x as a plain shift: the bit that mulx rotates into
		// coefficient 0 is cleared.
		c.mulx()
		c.s[0] &= ^uint(1)
		c.a[0] &= ^uint(1)

		degF--
		k += 1 & stillGoing
		// Remember f's constant term for as long as the loop is live.
		f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s)
		f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a)
		// stillGoing becomes zero once degF reaches zero (or wraps below it).
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}

	// Reduce k to at most N, then undo the accumulated factor of x^k.
	k -= N & lt(N, k)
	*out = b
	out.rot(k)
	out.mulConst(f0s, f0a)
	out.modPhiN()
}
| |
// invertMod2 sets out to the inverse of a modulo 2 (taking the low bit of each
// coefficient of a), with out's coefficients in {0, 1}. It uses the same
// fixed-iteration, branch-free almost-inverse loop as invertMod3, specialised
// to GF(2), where the multiplier is simply f's constant bit.
func (out *poly) invertMod2(a *poly) {
	// This algorithm follows a mix of algorithm 10 in the paper and the first
	// page of the PDF linked below. (Although note that the paper appears
	// to have a bug: k should start at zero, not one.) The best explanation
	// for why it works is in the "Why it works" section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)

	var f poly2
	f.fromDiscrete(a)
	var b, c, g poly2
	g.setPhiN()
	b[0] = 1

	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// s is f's constant bit, forced to zero once the loop has finished.
		s := uint(f[0]&1) & stillGoing
		// Swap when s is set and deg f < deg g.
		shouldSwap := ^(s - 1) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		// Eliminate f's constant term: f += s·g, b += s·c.
		f.fmadd(s, &g)
		b.fmadd(s, &c)

		// f /= x and c *= x (plain shifts, no wrap-around).
		f.rshift1()
		c.lshift1()

		degF--
		k += 1 & stillGoing
		// stillGoing becomes zero once degF reaches zero (or wraps below it).
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}

	// Reduce k to at most N, then undo the accumulated factor of x^k.
	k -= N & lt(N, k)
	b.rot(k)
	out.fromMod2(&b)
}
| |
| func (out *poly) invert(origA *poly) { |
| // Inversion mod Q, which is done based on the result of inverting mod |
| // 2. See the NTRU paper, page three. |
| var a, tmp, tmp2, b poly |
| b.invertMod2(origA) |
| |
| // Negate a. |
| for i := range a { |
| a[i] = Q - origA[i] |
| } |
| |
| // We are working mod Q=2**13 and we need to iterate ceil(log_2(13)) |
| // times, which is four. |
| for i := 0; i < 4; i++ { |
| tmp.mul(&a, &b) |
| tmp[0] += 2 |
| tmp2.mul(&b, &tmp) |
| b = tmp2 |
| } |
| |
| *out = b |
| } |
| |
| type PublicKey struct { |
| h poly |
| } |
| |
| func ParsePublicKey(in []byte) (*PublicKey, bool) { |
| ret := new(PublicKey) |
| if !ret.h.unmarshal(in) { |
| return nil, false |
| } |
| return ret, true |
| } |
| |
| func (pub *PublicKey) Marshal() []byte { |
| ret := make([]byte, modQBytes) |
| pub.h.marshal(ret) |
| return ret |
| } |
| |
| func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) { |
| var randBytes [352 + 352]byte |
| if _, err := io.ReadFull(rand, randBytes[:]); err != nil { |
| panic("rand failed") |
| } |
| |
| var m, r poly |
| m.shortSample(randBytes[:352]) |
| r.shortSample(randBytes[352:]) |
| |
| var mBytes, rBytes [mod3Bytes]byte |
| m.marshalS3(mBytes[:]) |
| r.marshalS3(rBytes[:]) |
| |
| ciphertext = pub.owf(&m, &r) |
| |
| h := sha256.New() |
| h.Write([]byte("shared key\x00")) |
| h.Write(mBytes[:]) |
| h.Write(rBytes[:]) |
| h.Write(ciphertext) |
| sharedKey = h.Sum(nil) |
| |
| return ciphertext, sharedKey |
| } |
| |
| func (pub *PublicKey) owf(m, r *poly) []byte { |
| for i := range r { |
| r[i] = mod3ToModQ(r[i]) |
| } |
| |
| var mq poly |
| mq.lift(m) |
| |
| var e poly |
| e.mul(r, &pub.h) |
| for i := range e { |
| e[i] = (e[i] + mq[i]) % Q |
| } |
| |
| ret := make([]byte, modQBytes) |
| e.marshal(ret[:]) |
| return ret |
| } |
| |
| type PrivateKey struct { |
| PublicKey |
| f, fp poly3 |
| hInv poly |
| hmacKey [32]byte |
| } |
| |
| func (priv *PrivateKey) Marshal() []byte { |
| var ret [2*mod3Bytes + modQBytes]byte |
| priv.f.marshal(ret[:]) |
| priv.fp.marshal(ret[mod3Bytes:]) |
| priv.h.marshal(ret[2*mod3Bytes:]) |
| return ret[:] |
| } |
| |
// Decap decapsulates ciphertext using priv and returns the shared key.
//
// The second result is false only when ciphertext has the wrong length or is
// not a valid encoding. When a well-formed ciphertext fails validation, Decap
// still returns true. Validation fails if the recovered r is out of range or
// the re-encryption does not match. In that case sharedKey is
// HMAC-SHA256(hmacKey, ciphertext), selected in constant time (implicit
// rejection).
func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) {
	if len(ciphertext) != modQBytes {
		return nil, false
	}

	var e poly
	if !e.unmarshal(ciphertext) {
		return nil, false
	}

	// Lift f to mod Q ({0, 1, Q-1}) and compute v1 = e·f mod (Q, x^N - 1).
	var f poly
	f.fromMod3ToModQ(&priv.f)

	var v1, m poly
	v1.mul(&e, &f)

	// Reduce v1 mod 3, treating each coefficient as a signed 13-bit value.
	var v13 poly3
	v13.fromDiscreteMod3(&v1)
	// Note: v13 is not reduced mod phi(n).

	// m = v13·fp mod (3, Φ(N)).
	// NOTE(review): mulMod3 already ends with modPhiN, which zeroes the top
	// coefficient, so this second modPhiN looks like a no-op. It is kept as-is.
	var m3 poly3
	m3.mulMod3(&v13, &priv.fp)
	m3.modPhiN()
	m.fromMod3(&m3)

	// Recover r = (e - lift(m))·h⁻¹ mod (Q, Φ(N)).
	var mLift, delta poly
	mLift.lift(&m)
	for i := range delta {
		delta[i] = (e[i] - mLift[i] + Q) % Q
	}
	delta.mul(&delta, &priv.hInv)
	delta.modPhiN()

	// allOk becomes 0 if any coefficient of r is outside {0, 1, Q-1}.
	var r poly3
	allOk := r.fromModQ(&delta)

	var mBytes, rBytes [mod3Bytes]byte
	m.marshalS3(mBytes[:])
	r.marshal(rBytes[:])

	// Re-encrypt (m, r) and require that it reproduces the ciphertext exactly.
	var rPoly poly
	rPoly.fromMod3(&r)
	expectedCiphertext := priv.PublicKey.owf(&m, &rPoly)

	allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext)

	// Both candidate keys are always computed so that timing does not reveal
	// which one is returned.
	hmacHash := hmac.New(sha256.New, priv.hmacKey[:])
	hmacHash.Write(ciphertext)
	hmacDigest := hmacHash.Sum(nil)

	h := sha256.New()
	h.Write([]byte("shared key\x00"))
	h.Write(mBytes[:])
	h.Write(rBytes[:])
	h.Write(ciphertext)
	sharedKey = h.Sum(nil)

	// allOk is 1 on success and 0 on failure, so mask is 0x00 or 0xff and
	// selects the real key or the HMAC digest without branching.
	mask := uint8(allOk - 1)
	for i := range sharedKey {
		sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask)
	}

	return sharedKey, true
}
| |
| func GenerateKey(rand io.Reader) PrivateKey { |
| var randBytes [352 + 352]byte |
| if _, err := io.ReadFull(rand, randBytes[:]); err != nil { |
| panic("rand failed") |
| } |
| |
| var f poly |
| f.shortSamplePlus(randBytes[:352]) |
| var priv PrivateKey |
| priv.f.fromDiscrete(&f) |
| priv.fp.invertMod3(&priv.f) |
| |
| var g poly |
| g.shortSamplePlus(randBytes[352:]) |
| |
| var pgPhi1 poly |
| for i := range g { |
| pgPhi1[i] = mod3ToModQ(g[i]) |
| } |
| for i := range pgPhi1 { |
| pgPhi1[i] = (pgPhi1[i] * 3) % Q |
| } |
| pgPhi1.mulXMinus1() |
| |
| var fModQ poly |
| fModQ.fromMod3ToModQ(&priv.f) |
| |
| var pfgPhi1 poly |
| pfgPhi1.mul(&fModQ, &pgPhi1) |
| |
| var i poly |
| i.invert(&pfgPhi1) |
| |
| priv.h.mul(&i, &pgPhi1) |
| priv.h.mul(&priv.h, &pgPhi1) |
| |
| priv.hInv.mul(&i, &fModQ) |
| priv.hInv.mul(&priv.hInv, &fModQ) |
| |
| return priv |
| } |