fe_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. // Copyright (c) 2017 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package field
  5. import (
  6. "bytes"
  7. "crypto/rand"
  8. "encoding/hex"
  9. "io"
  10. "math/big"
  11. "math/bits"
  12. mathrand "math/rand"
  13. "reflect"
  14. "testing"
  15. "testing/quick"
  16. )
  17. func (v Element) String() string {
  18. return hex.EncodeToString(v.Bytes())
  19. }
  20. // quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks)
  21. // times. The default value of -quickchecks is 100.
  22. var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10}
  23. func generateFieldElement(rand *mathrand.Rand) Element {
  24. const maskLow52Bits = (1 << 52) - 1
  25. return Element{
  26. rand.Uint64() & maskLow52Bits,
  27. rand.Uint64() & maskLow52Bits,
  28. rand.Uint64() & maskLow52Bits,
  29. rand.Uint64() & maskLow52Bits,
  30. rand.Uint64() & maskLow52Bits,
  31. }
  32. }
  33. // weirdLimbs can be combined to generate a range of edge-case field elements.
  34. // 0 and -1 are intentionally more weighted, as they combine well.
  35. var (
  36. weirdLimbs51 = []uint64{
  37. 0, 0, 0, 0,
  38. 1,
  39. 19 - 1,
  40. 19,
  41. 0x2aaaaaaaaaaaa,
  42. 0x5555555555555,
  43. (1 << 51) - 20,
  44. (1 << 51) - 19,
  45. (1 << 51) - 1, (1 << 51) - 1,
  46. (1 << 51) - 1, (1 << 51) - 1,
  47. }
  48. weirdLimbs52 = []uint64{
  49. 0, 0, 0, 0, 0, 0,
  50. 1,
  51. 19 - 1,
  52. 19,
  53. 0x2aaaaaaaaaaaa,
  54. 0x5555555555555,
  55. (1 << 51) - 20,
  56. (1 << 51) - 19,
  57. (1 << 51) - 1, (1 << 51) - 1,
  58. (1 << 51) - 1, (1 << 51) - 1,
  59. (1 << 51) - 1, (1 << 51) - 1,
  60. 1 << 51,
  61. (1 << 51) + 1,
  62. (1 << 52) - 19,
  63. (1 << 52) - 1,
  64. }
  65. )
  66. func generateWeirdFieldElement(rand *mathrand.Rand) Element {
  67. return Element{
  68. weirdLimbs52[rand.Intn(len(weirdLimbs52))],
  69. weirdLimbs51[rand.Intn(len(weirdLimbs51))],
  70. weirdLimbs51[rand.Intn(len(weirdLimbs51))],
  71. weirdLimbs51[rand.Intn(len(weirdLimbs51))],
  72. weirdLimbs51[rand.Intn(len(weirdLimbs51))],
  73. }
  74. }
  75. func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
  76. if rand.Intn(2) == 0 {
  77. return reflect.ValueOf(generateWeirdFieldElement(rand))
  78. }
  79. return reflect.ValueOf(generateFieldElement(rand))
  80. }
  81. // isInBounds returns whether the element is within the expected bit size bounds
  82. // after a light reduction.
  83. func isInBounds(x *Element) bool {
  84. return bits.Len64(x.l0) <= 52 &&
  85. bits.Len64(x.l1) <= 52 &&
  86. bits.Len64(x.l2) <= 52 &&
  87. bits.Len64(x.l3) <= 52 &&
  88. bits.Len64(x.l4) <= 52
  89. }
  90. func TestMultiplyDistributesOverAdd(t *testing.T) {
  91. multiplyDistributesOverAdd := func(x, y, z Element) bool {
  92. // Compute t1 = (x+y)*z
  93. t1 := new(Element)
  94. t1.Add(&x, &y)
  95. t1.Multiply(t1, &z)
  96. // Compute t2 = x*z + y*z
  97. t2 := new(Element)
  98. t3 := new(Element)
  99. t2.Multiply(&x, &z)
  100. t3.Multiply(&y, &z)
  101. t2.Add(t2, t3)
  102. return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
  103. }
  104. if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
  105. t.Error(err)
  106. }
  107. }
  108. func TestMul64to128(t *testing.T) {
  109. a := uint64(5)
  110. b := uint64(5)
  111. r := mul64(a, b)
  112. if r.lo != 0x19 || r.hi != 0 {
  113. t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
  114. }
  115. a = uint64(18014398509481983) // 2^54 - 1
  116. b = uint64(18014398509481983) // 2^54 - 1
  117. r = mul64(a, b)
  118. if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
  119. t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
  120. }
  121. a = uint64(1125899906842661)
  122. b = uint64(2097155)
  123. r = mul64(a, b)
  124. r = addMul64(r, a, b)
  125. r = addMul64(r, a, b)
  126. r = addMul64(r, a, b)
  127. r = addMul64(r, a, b)
  128. if r.lo != 16888498990613035 || r.hi != 640 {
  129. t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
  130. }
  131. }
  132. func TestSetBytesRoundTrip(t *testing.T) {
  133. f1 := func(in [32]byte, fe Element) bool {
  134. fe.SetBytes(in[:])
  135. // Mask the most significant bit as it's ignored by SetBytes. (Now
  136. // instead of earlier so we check the masking in SetBytes is working.)
  137. in[len(in)-1] &= (1 << 7) - 1
  138. return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
  139. }
  140. if err := quick.Check(f1, nil); err != nil {
  141. t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
  142. }
  143. f2 := func(fe, r Element) bool {
  144. r.SetBytes(fe.Bytes())
  145. // Intentionally not using Equal not to go through Bytes again.
  146. // Calling reduce because both Generate and SetBytes can produce
  147. // non-canonical representations.
  148. fe.reduce()
  149. r.reduce()
  150. return fe == r
  151. }
  152. if err := quick.Check(f2, nil); err != nil {
  153. t.Errorf("failed FE->bytes->FE round-trip: %v", err)
  154. }
  155. // Check some fixed vectors from dalek
  156. type feRTTest struct {
  157. fe Element
  158. b []byte
  159. }
  160. var tests = []feRTTest{
  161. {
  162. fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
  163. b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
  164. },
  165. {
  166. fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
  167. b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
  168. },
  169. }
  170. for _, tt := range tests {
  171. b := tt.fe.Bytes()
  172. if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 {
  173. t.Errorf("Failed fixed roundtrip: %v", tt)
  174. }
  175. }
  176. }
  177. func swapEndianness(buf []byte) []byte {
  178. for i := 0; i < len(buf)/2; i++ {
  179. buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
  180. }
  181. return buf
  182. }
  183. func TestBytesBigEquivalence(t *testing.T) {
  184. f1 := func(in [32]byte, fe, fe1 Element) bool {
  185. fe.SetBytes(in[:])
  186. in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
  187. b := new(big.Int).SetBytes(swapEndianness(in[:]))
  188. fe1.fromBig(b)
  189. if fe != fe1 {
  190. return false
  191. }
  192. buf := make([]byte, 32) // pad with zeroes
  193. copy(buf, swapEndianness(fe1.toBig().Bytes()))
  194. return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
  195. }
  196. if err := quick.Check(f1, nil); err != nil {
  197. t.Error(err)
  198. }
  199. }
  200. // fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
  201. func (v *Element) fromBig(n *big.Int) *Element {
  202. if n.BitLen() > 32*8 {
  203. panic("edwards25519: invalid field element input size")
  204. }
  205. buf := make([]byte, 0, 32)
  206. for _, word := range n.Bits() {
  207. for i := 0; i < bits.UintSize; i += 8 {
  208. if len(buf) >= cap(buf) {
  209. break
  210. }
  211. buf = append(buf, byte(word))
  212. word >>= 8
  213. }
  214. }
  215. return v.SetBytes(buf[:32])
  216. }
  217. func (v *Element) fromDecimal(s string) *Element {
  218. n, ok := new(big.Int).SetString(s, 10)
  219. if !ok {
  220. panic("not a valid decimal: " + s)
  221. }
  222. return v.fromBig(n)
  223. }
  224. // toBig returns v as a big.Int.
  225. func (v *Element) toBig() *big.Int {
  226. buf := v.Bytes()
  227. words := make([]big.Word, 32*8/bits.UintSize)
  228. for n := range words {
  229. for i := 0; i < bits.UintSize; i += 8 {
  230. if len(buf) == 0 {
  231. break
  232. }
  233. words[n] |= big.Word(buf[0]) << big.Word(i)
  234. buf = buf[1:]
  235. }
  236. }
  237. return new(big.Int).SetBits(words)
  238. }
  239. func TestDecimalConstants(t *testing.T) {
  240. sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
  241. if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
  242. t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
  243. }
  244. // d is in the parent package, and we don't want to expose d or fromDecimal.
  245. // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
  246. // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
  247. // t.Errorf("d is %v, expected %v", d, exp)
  248. // }
  249. }
  250. func TestSetBytesRoundTripEdgeCases(t *testing.T) {
  251. // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
  252. // and between 2^255 and 2^256-1. Test both the documented SetBytes
  253. // behavior, and that Bytes reduces them.
  254. }
  255. // Tests self-consistency between Multiply and Square.
  256. func TestConsistency(t *testing.T) {
  257. var x Element
  258. var x2, x2sq Element
  259. x = Element{1, 1, 1, 1, 1}
  260. x2.Multiply(&x, &x)
  261. x2sq.Square(&x)
  262. if x2 != x2sq {
  263. t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
  264. }
  265. var bytes [32]byte
  266. _, err := io.ReadFull(rand.Reader, bytes[:])
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. x.SetBytes(bytes[:])
  271. x2.Multiply(&x, &x)
  272. x2sq.Square(&x)
  273. if x2 != x2sq {
  274. t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
  275. }
  276. }
  277. func TestEqual(t *testing.T) {
  278. x := Element{1, 1, 1, 1, 1}
  279. y := Element{5, 4, 3, 2, 1}
  280. eq := x.Equal(&x)
  281. if eq != 1 {
  282. t.Errorf("wrong about equality")
  283. }
  284. eq = x.Equal(&y)
  285. if eq != 0 {
  286. t.Errorf("wrong about inequality")
  287. }
  288. }
  289. func TestInvert(t *testing.T) {
  290. x := Element{1, 1, 1, 1, 1}
  291. one := Element{1, 0, 0, 0, 0}
  292. var xinv, r Element
  293. xinv.Invert(&x)
  294. r.Multiply(&x, &xinv)
  295. r.reduce()
  296. if one != r {
  297. t.Errorf("inversion identity failed, got: %x", r)
  298. }
  299. var bytes [32]byte
  300. _, err := io.ReadFull(rand.Reader, bytes[:])
  301. if err != nil {
  302. t.Fatal(err)
  303. }
  304. x.SetBytes(bytes[:])
  305. xinv.Invert(&x)
  306. r.Multiply(&x, &xinv)
  307. r.reduce()
  308. if one != r {
  309. t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
  310. }
  311. zero := Element{}
  312. x.Set(&zero)
  313. if xx := xinv.Invert(&x); xx != &xinv {
  314. t.Errorf("inverting zero did not return the receiver")
  315. } else if xinv.Equal(&zero) != 1 {
  316. t.Errorf("inverting zero did not return zero")
  317. }
  318. }
  319. func TestSelectSwap(t *testing.T) {
  320. a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
  321. b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
  322. var c, d Element
  323. c.Select(&a, &b, 1)
  324. d.Select(&a, &b, 0)
  325. if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
  326. t.Errorf("Select failed")
  327. }
  328. c.Swap(&d, 0)
  329. if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
  330. t.Errorf("Swap failed")
  331. }
  332. c.Swap(&d, 1)
  333. if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
  334. t.Errorf("Swap failed")
  335. }
  336. }
  337. func TestMult32(t *testing.T) {
  338. mult32EquivalentToMul := func(x Element, y uint32) bool {
  339. t1 := new(Element)
  340. for i := 0; i < 100; i++ {
  341. t1.Mult32(&x, y)
  342. }
  343. ty := new(Element)
  344. ty.l0 = uint64(y)
  345. t2 := new(Element)
  346. for i := 0; i < 100; i++ {
  347. t2.Multiply(&x, ty)
  348. }
  349. return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
  350. }
  351. if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
  352. t.Error(err)
  353. }
  354. }
  355. func TestSqrtRatio(t *testing.T) {
  356. // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
  357. type test struct {
  358. u, v string
  359. wasSquare int
  360. r string
  361. }
  362. var tests = []test{
  363. // If u is 0, the function is defined to return (0, TRUE), even if v
  364. // is zero. Note that where used in this package, the denominator v
  365. // is never zero.
  366. {
  367. "0000000000000000000000000000000000000000000000000000000000000000",
  368. "0000000000000000000000000000000000000000000000000000000000000000",
  369. 1, "0000000000000000000000000000000000000000000000000000000000000000",
  370. },
  371. // 0/1 == 0²
  372. {
  373. "0000000000000000000000000000000000000000000000000000000000000000",
  374. "0100000000000000000000000000000000000000000000000000000000000000",
  375. 1, "0000000000000000000000000000000000000000000000000000000000000000",
  376. },
  377. // If u is non-zero and v is zero, defined to return (0, FALSE).
  378. {
  379. "0100000000000000000000000000000000000000000000000000000000000000",
  380. "0000000000000000000000000000000000000000000000000000000000000000",
  381. 0, "0000000000000000000000000000000000000000000000000000000000000000",
  382. },
  383. // 2/1 is not square in this field.
  384. {
  385. "0200000000000000000000000000000000000000000000000000000000000000",
  386. "0100000000000000000000000000000000000000000000000000000000000000",
  387. 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
  388. },
  389. // 4/1 == 2²
  390. {
  391. "0400000000000000000000000000000000000000000000000000000000000000",
  392. "0100000000000000000000000000000000000000000000000000000000000000",
  393. 1, "0200000000000000000000000000000000000000000000000000000000000000",
  394. },
  395. // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
  396. {
  397. "0100000000000000000000000000000000000000000000000000000000000000",
  398. "0400000000000000000000000000000000000000000000000000000000000000",
  399. 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
  400. },
  401. }
  402. for i, tt := range tests {
  403. u := new(Element).SetBytes(decodeHex(tt.u))
  404. v := new(Element).SetBytes(decodeHex(tt.v))
  405. want := new(Element).SetBytes(decodeHex(tt.r))
  406. got, wasSquare := new(Element).SqrtRatio(u, v)
  407. if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
  408. t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
  409. }
  410. }
  411. }
  412. func TestCarryPropagate(t *testing.T) {
  413. asmLikeGeneric := func(a [5]uint64) bool {
  414. t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
  415. t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
  416. t1.carryPropagate()
  417. t2.carryPropagateGeneric()
  418. if *t1 != *t2 {
  419. t.Logf("got: %#v,\nexpected: %#v", t1, t2)
  420. }
  421. return *t1 == *t2 && isInBounds(t2)
  422. }
  423. if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
  424. t.Error(err)
  425. }
  426. if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
  427. t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
  428. }
  429. }
  430. func TestFeSquare(t *testing.T) {
  431. asmLikeGeneric := func(a Element) bool {
  432. t1 := a
  433. t2 := a
  434. feSquareGeneric(&t1, &t1)
  435. feSquare(&t2, &t2)
  436. if t1 != t2 {
  437. t.Logf("got: %#v,\nexpected: %#v", t1, t2)
  438. }
  439. return t1 == t2 && isInBounds(&t2)
  440. }
  441. if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
  442. t.Error(err)
  443. }
  444. }
  445. func TestFeMul(t *testing.T) {
  446. asmLikeGeneric := func(a, b Element) bool {
  447. a1 := a
  448. a2 := a
  449. b1 := b
  450. b2 := b
  451. feMulGeneric(&a1, &a1, &b1)
  452. feMul(&a2, &a2, &b2)
  453. if a1 != a2 || b1 != b2 {
  454. t.Logf("got: %#v,\nexpected: %#v", a1, a2)
  455. t.Logf("got: %#v,\nexpected: %#v", b1, b2)
  456. }
  457. return a1 == a2 && isInBounds(&a2) &&
  458. b1 == b2 && isInBounds(&b2)
  459. }
  460. if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
  461. t.Error(err)
  462. }
  463. }
  464. func decodeHex(s string) []byte {
  465. b, err := hex.DecodeString(s)
  466. if err != nil {
  467. panic(err)
  468. }
  469. return b
  470. }