readfrom_linux_test.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. // Copyright 2020 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package os_test
  5. import (
  6. "bytes"
  7. "internal/poll"
  8. "io"
  9. "math/rand"
  10. "os"
  11. . "os"
  12. "path/filepath"
  13. "strconv"
  14. "syscall"
  15. "testing"
  16. "time"
  17. )
  18. func TestCopyFileRange(t *testing.T) {
  19. sizes := []int{
  20. 1,
  21. 42,
  22. 1025,
  23. syscall.Getpagesize() + 1,
  24. 32769,
  25. }
  26. t.Run("Basic", func(t *testing.T) {
  27. for _, size := range sizes {
  28. t.Run(strconv.Itoa(size), func(t *testing.T) {
  29. testCopyFileRange(t, int64(size), -1)
  30. })
  31. }
  32. })
  33. t.Run("Limited", func(t *testing.T) {
  34. t.Run("OneLess", func(t *testing.T) {
  35. for _, size := range sizes {
  36. t.Run(strconv.Itoa(size), func(t *testing.T) {
  37. testCopyFileRange(t, int64(size), int64(size)-1)
  38. })
  39. }
  40. })
  41. t.Run("Half", func(t *testing.T) {
  42. for _, size := range sizes {
  43. t.Run(strconv.Itoa(size), func(t *testing.T) {
  44. testCopyFileRange(t, int64(size), int64(size)/2)
  45. })
  46. }
  47. })
  48. t.Run("More", func(t *testing.T) {
  49. for _, size := range sizes {
  50. t.Run(strconv.Itoa(size), func(t *testing.T) {
  51. testCopyFileRange(t, int64(size), int64(size)+7)
  52. })
  53. }
  54. })
  55. })
  56. t.Run("DoesntTryInAppendMode", func(t *testing.T) {
  57. dst, src, data, hook := newCopyFileRangeTest(t, 42)
  58. dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. defer dst2.Close()
  63. if _, err := io.Copy(dst2, src); err != nil {
  64. t.Fatal(err)
  65. }
  66. if hook.called {
  67. t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
  68. }
  69. mustSeekStart(t, dst2)
  70. mustContainData(t, dst2, data) // through traditional means
  71. })
  72. t.Run("NotRegular", func(t *testing.T) {
  73. t.Run("BothPipes", func(t *testing.T) {
  74. hook := hookCopyFileRange(t)
  75. pr1, pw1, err := Pipe()
  76. if err != nil {
  77. t.Fatal(err)
  78. }
  79. defer pr1.Close()
  80. defer pw1.Close()
  81. pr2, pw2, err := Pipe()
  82. if err != nil {
  83. t.Fatal(err)
  84. }
  85. defer pr2.Close()
  86. defer pw2.Close()
  87. // The pipe is empty, and PIPE_BUF is large enough
  88. // for this, by (POSIX) definition, so there is no
  89. // need for an additional goroutine.
  90. data := []byte("hello")
  91. if _, err := pw1.Write(data); err != nil {
  92. t.Fatal(err)
  93. }
  94. pw1.Close()
  95. n, err := io.Copy(pw2, pr1)
  96. if err != nil {
  97. t.Fatal(err)
  98. }
  99. if n != int64(len(data)) {
  100. t.Fatalf("transferred %d, want %d", n, len(data))
  101. }
  102. if !hook.called {
  103. t.Fatalf("should have called poll.CopyFileRange")
  104. }
  105. pw2.Close()
  106. mustContainData(t, pr2, data)
  107. })
  108. t.Run("DstPipe", func(t *testing.T) {
  109. dst, src, data, hook := newCopyFileRangeTest(t, 255)
  110. dst.Close()
  111. pr, pw, err := Pipe()
  112. if err != nil {
  113. t.Fatal(err)
  114. }
  115. defer pr.Close()
  116. defer pw.Close()
  117. n, err := io.Copy(pw, src)
  118. if err != nil {
  119. t.Fatal(err)
  120. }
  121. if n != int64(len(data)) {
  122. t.Fatalf("transferred %d, want %d", n, len(data))
  123. }
  124. if !hook.called {
  125. t.Fatalf("should have called poll.CopyFileRange")
  126. }
  127. pw.Close()
  128. mustContainData(t, pr, data)
  129. })
  130. t.Run("SrcPipe", func(t *testing.T) {
  131. dst, src, data, hook := newCopyFileRangeTest(t, 255)
  132. src.Close()
  133. pr, pw, err := Pipe()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. defer pr.Close()
  138. defer pw.Close()
  139. // The pipe is empty, and PIPE_BUF is large enough
  140. // for this, by (POSIX) definition, so there is no
  141. // need for an additional goroutine.
  142. if _, err := pw.Write(data); err != nil {
  143. t.Fatal(err)
  144. }
  145. pw.Close()
  146. n, err := io.Copy(dst, pr)
  147. if err != nil {
  148. t.Fatal(err)
  149. }
  150. if n != int64(len(data)) {
  151. t.Fatalf("transferred %d, want %d", n, len(data))
  152. }
  153. if !hook.called {
  154. t.Fatalf("should have called poll.CopyFileRange")
  155. }
  156. mustSeekStart(t, dst)
  157. mustContainData(t, dst, data)
  158. })
  159. })
  160. t.Run("Nil", func(t *testing.T) {
  161. var nilFile *File
  162. anyFile, err := os.CreateTemp("", "")
  163. if err != nil {
  164. t.Fatal(err)
  165. }
  166. defer Remove(anyFile.Name())
  167. defer anyFile.Close()
  168. if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
  169. t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
  170. }
  171. if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
  172. t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
  173. }
  174. if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
  175. t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
  176. }
  177. if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
  178. t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
  179. }
  180. if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
  181. t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
  182. }
  183. if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
  184. t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
  185. }
  186. })
  187. }
  188. func testCopyFileRange(t *testing.T, size int64, limit int64) {
  189. dst, src, data, hook := newCopyFileRangeTest(t, size)
  190. // If we have a limit, wrap the reader.
  191. var (
  192. realsrc io.Reader
  193. lr *io.LimitedReader
  194. )
  195. if limit >= 0 {
  196. lr = &io.LimitedReader{N: limit, R: src}
  197. realsrc = lr
  198. if limit < int64(len(data)) {
  199. data = data[:limit]
  200. }
  201. } else {
  202. realsrc = src
  203. }
  204. // Now call ReadFrom (through io.Copy), which will hopefully call
  205. // poll.CopyFileRange.
  206. n, err := io.Copy(dst, realsrc)
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. // If we didn't have a limit, we should have called poll.CopyFileRange
  211. // with the right file descriptor arguments.
  212. if limit > 0 && !hook.called {
  213. t.Fatal("never called poll.CopyFileRange")
  214. }
  215. if hook.called && hook.dstfd != int(dst.Fd()) {
  216. t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
  217. }
  218. if hook.called && hook.srcfd != int(src.Fd()) {
  219. t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
  220. }
  221. // Check that the offsets after the transfer make sense, that the size
  222. // of the transfer was reported correctly, and that the destination
  223. // file contains exactly the bytes we expect it to contain.
  224. dstoff, err := dst.Seek(0, io.SeekCurrent)
  225. if err != nil {
  226. t.Fatal(err)
  227. }
  228. srcoff, err := src.Seek(0, io.SeekCurrent)
  229. if err != nil {
  230. t.Fatal(err)
  231. }
  232. if dstoff != srcoff {
  233. t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
  234. }
  235. if dstoff != int64(len(data)) {
  236. t.Errorf("dstoff = %d, want %d", dstoff, len(data))
  237. }
  238. if n != int64(len(data)) {
  239. t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
  240. }
  241. mustSeekStart(t, dst)
  242. mustContainData(t, dst, data)
  243. // If we had a limit, check that it was updated.
  244. if lr != nil {
  245. if want := limit - n; lr.N != want {
  246. t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
  247. }
  248. }
  249. }
  250. // newCopyFileRangeTest initializes a new test for copy_file_range.
  251. //
  252. // It creates source and destination files, and populates the source file
  253. // with random data of the specified size. It also hooks package os' call
  254. // to poll.CopyFileRange and returns the hook so it can be inspected.
  255. func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
  256. t.Helper()
  257. hook = hookCopyFileRange(t)
  258. tmp := t.TempDir()
  259. src, err := Create(filepath.Join(tmp, "src"))
  260. if err != nil {
  261. t.Fatal(err)
  262. }
  263. t.Cleanup(func() { src.Close() })
  264. dst, err = Create(filepath.Join(tmp, "dst"))
  265. if err != nil {
  266. t.Fatal(err)
  267. }
  268. t.Cleanup(func() { dst.Close() })
  269. // Populate the source file with data, then rewind it, so it can be
  270. // consumed by copy_file_range(2).
  271. prng := rand.New(rand.NewSource(time.Now().Unix()))
  272. data = make([]byte, size)
  273. prng.Read(data)
  274. if _, err := src.Write(data); err != nil {
  275. t.Fatal(err)
  276. }
  277. if _, err := src.Seek(0, io.SeekStart); err != nil {
  278. t.Fatal(err)
  279. }
  280. return dst, src, data, hook
  281. }
  // mustContainData reads f from its current offset and fails the test unless
// exactly len(data) bytes equal to data remain before EOF.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()

	buf := make([]byte, len(data))
	_, err := io.ReadFull(f, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, data) {
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}

	// Any further byte means the file holds more than data.
	var probe [1]byte
	if _, err = f.Read(probe[:]); err != io.EOF {
		t.Fatalf("not at EOF")
	}
}
  // mustSeekStart rewinds f to offset 0, failing the test immediately if the
// seek fails. It is marked as a test helper so failures are reported at the
// caller's line, like the other must* helpers.
func mustSeekStart(t *testing.T, f *File) {
	t.Helper()
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
  302. func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
  303. h := new(copyFileRangeHook)
  304. h.install()
  305. t.Cleanup(h.uninstall)
  306. return h
  307. }
  308. type copyFileRangeHook struct {
  309. called bool
  310. dstfd int
  311. srcfd int
  312. remain int64
  313. original func(dst, src *poll.FD, remain int64) (int64, bool, error)
  314. }
  315. func (h *copyFileRangeHook) install() {
  316. h.original = *PollCopyFileRangeP
  317. *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
  318. h.called = true
  319. h.dstfd = dst.Sysfd
  320. h.srcfd = src.Sysfd
  321. h.remain = remain
  322. return h.original(dst, src, remain)
  323. }
  324. }
  325. func (h *copyFileRangeHook) uninstall() {
  326. *PollCopyFileRangeP = h.original
  327. }
  // TestProcCopy checks that io.Copy from /proc/self/cmdline into a regular
// file reproduces the /proc file's contents exactly. On some kernels
// copy_file_range fails on files in /proc. The test is skipped when the
// /proc file cannot be read.
func TestProcCopy(t *testing.T) {
	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := os.ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := os.Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()

	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := os.Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	// Make sure out is closed even if io.Copy fails and t.Fatal exits early.
	// On the success path the explicit Close below checks for errors; the
	// deferred second Close then only returns an error that is ignored.
	defer out.Close()
	if _, err := io.Copy(out, in); err != nil {
		t.Fatal(err)
	}
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}

	// Named got rather than copy to avoid shadowing the builtin copy.
	got, err := os.ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}