Add drain and close function (#910)
This commit is contained in:
commit
de24996538
@ -2,7 +2,6 @@ version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- bodyclose
|
||||
- errcheck
|
||||
- gocritic
|
||||
- govet
|
||||
|
@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- [#905](https://github.com/spegel-org/spegel/pull/905) Change mirror type to url and add byte range parameter.
|
||||
- [#909](https://github.com/spegel-org/spegel/pull/909) Add base http client and transport.
|
||||
- [#910](https://github.com/spegel-org/spegel/pull/910) Add drain and close function.
|
||||
|
||||
### Changed
|
||||
|
||||
|
@ -3,7 +3,6 @@ package cleanup
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -130,11 +129,7 @@ func probeIPs(ctx context.Context, client *http.Client, ips []net.IPAddr, port s
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = io.Copy(io.Discard, resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer httpx.DrainAndClose(resp.Body)
|
||||
err = httpx.CheckResponseStatus(resp, http.StatusOK)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1,6 +1,8 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
@ -41,3 +43,25 @@ func BaseTransport() *http.Transport {
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// MaxReadBytes is the maximum amount of bytes read when draining a response or reading error message.
|
||||
MaxReadBytes = 512 * 1024
|
||||
)
|
||||
|
||||
// DrainAndCloses empties the body buffer before closing the body.
|
||||
func DrainAndClose(rc io.ReadCloser) error {
|
||||
errs := []error{}
|
||||
n, err := io.Copy(io.Discard, io.LimitReader(rc, MaxReadBytes+1))
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if n > MaxReadBytes {
|
||||
errs = append(errs, errors.New("reader has more data than max read bytes"))
|
||||
}
|
||||
err = rc.Close()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
@ -22,3 +24,22 @@ func TestBaseTransport(t *testing.T) {
|
||||
|
||||
BaseTransport()
|
||||
}
|
||||
|
||||
func TestDrainAndClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
err := DrainAndClose(io.NopCloser(buf))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, buf.Bytes())
|
||||
|
||||
buf = bytes.NewBuffer(make([]byte, MaxReadBytes))
|
||||
err = DrainAndClose(io.NopCloser(buf))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, buf.Bytes())
|
||||
|
||||
buf = bytes.NewBuffer(make([]byte, MaxReadBytes+10))
|
||||
err = DrainAndClose(io.NopCloser(buf))
|
||||
require.EqualError(t, err, "reader has more data than max read bytes")
|
||||
require.Len(t, buf.Bytes(), 9)
|
||||
}
|
||||
|
@ -55,9 +55,11 @@ func getErrorMessage(resp *http.Response) (string, error) {
|
||||
"application/xml",
|
||||
}
|
||||
if !slices.Contains(contentTypes, resp.Header.Get(HeaderContentType)) {
|
||||
_, err := io.Copy(io.Discard, resp.Body)
|
||||
return "", nil
|
||||
}
|
||||
b, err := io.ReadAll(io.LimitReader(resp.Body, MaxReadBytes))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
return string(b), err
|
||||
}
|
||||
|
@ -63,66 +63,74 @@ func (c *Client) Pull(ctx context.Context, img Image, mirror *url.URL) ([]PullMe
|
||||
queue = queue[1:]
|
||||
|
||||
start := time.Now()
|
||||
rc, desc, err := c.Get(ctx, dist, mirror, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
desc, err := func() (ocispec.Descriptor, error) {
|
||||
rc, desc, err := c.Get(ctx, dist, mirror, nil)
|
||||
if err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
defer httpx.DrainAndClose(rc)
|
||||
|
||||
switch dist.Kind {
|
||||
case DistributionKindBlob:
|
||||
_, copyErr := io.Copy(io.Discard, rc)
|
||||
closeErr := rc.Close()
|
||||
err := errors.Join(copyErr, closeErr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case DistributionKindManifest:
|
||||
b, readErr := io.ReadAll(rc)
|
||||
closeErr := rc.Close()
|
||||
err = errors.Join(readErr, closeErr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch desc.MediaType {
|
||||
case images.MediaTypeDockerSchema2ManifestList, ocispec.MediaTypeImageIndex:
|
||||
var idx ocispec.Index
|
||||
if err := json.Unmarshal(b, &idx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, m := range idx.Manifests {
|
||||
// TODO: Add platform option.
|
||||
//nolint: staticcheck // Simplify in the future.
|
||||
if !(m.Platform.OS == runtime.GOOS && m.Platform.Architecture == runtime.GOARCH) {
|
||||
continue
|
||||
}
|
||||
queue = append(queue, DistributionPath{
|
||||
Kind: DistributionKindManifest,
|
||||
Name: dist.Name,
|
||||
Digest: m.Digest,
|
||||
Registry: dist.Registry,
|
||||
})
|
||||
}
|
||||
case images.MediaTypeDockerSchema2Manifest, ocispec.MediaTypeImageManifest:
|
||||
var manifest ocispec.Manifest
|
||||
err := json.Unmarshal(b, &manifest)
|
||||
switch dist.Kind {
|
||||
case DistributionKindBlob:
|
||||
// Right now we are just discarding the contents because we do not have a writable store.
|
||||
_, copyErr := io.Copy(io.Discard, rc)
|
||||
closeErr := rc.Close()
|
||||
err := errors.Join(copyErr, closeErr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
queue = append(queue, DistributionPath{
|
||||
Kind: DistributionKindBlob,
|
||||
Name: dist.Name,
|
||||
Digest: manifest.Config.Digest,
|
||||
Registry: dist.Registry,
|
||||
})
|
||||
for _, layer := range manifest.Layers {
|
||||
case DistributionKindManifest:
|
||||
b, readErr := io.ReadAll(rc)
|
||||
closeErr := rc.Close()
|
||||
err = errors.Join(readErr, closeErr)
|
||||
if err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
switch desc.MediaType {
|
||||
case images.MediaTypeDockerSchema2ManifestList, ocispec.MediaTypeImageIndex:
|
||||
var idx ocispec.Index
|
||||
if err := json.Unmarshal(b, &idx); err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
for _, m := range idx.Manifests {
|
||||
// TODO: Add platform option.
|
||||
//nolint: staticcheck // Simplify in the future.
|
||||
if !(m.Platform.OS == runtime.GOOS && m.Platform.Architecture == runtime.GOARCH) {
|
||||
continue
|
||||
}
|
||||
queue = append(queue, DistributionPath{
|
||||
Kind: DistributionKindManifest,
|
||||
Name: dist.Name,
|
||||
Digest: m.Digest,
|
||||
Registry: dist.Registry,
|
||||
})
|
||||
}
|
||||
case images.MediaTypeDockerSchema2Manifest, ocispec.MediaTypeImageManifest:
|
||||
var manifest ocispec.Manifest
|
||||
err := json.Unmarshal(b, &manifest)
|
||||
if err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
queue = append(queue, DistributionPath{
|
||||
Kind: DistributionKindBlob,
|
||||
Name: dist.Name,
|
||||
Digest: layer.Digest,
|
||||
Digest: manifest.Config.Digest,
|
||||
Registry: dist.Registry,
|
||||
})
|
||||
for _, layer := range manifest.Layers {
|
||||
queue = append(queue, DistributionPath{
|
||||
Kind: DistributionKindBlob,
|
||||
Name: dist.Name,
|
||||
Digest: layer.Digest,
|
||||
Registry: dist.Registry,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return desc, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metric := PullMetric{
|
||||
@ -142,11 +150,7 @@ func (c *Client) Head(ctx context.Context, dist DistributionPath, mirror *url.UR
|
||||
if err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
defer rc.Close()
|
||||
_, err = io.Copy(io.Discard, rc)
|
||||
if err != nil {
|
||||
return ocispec.Descriptor{}, err
|
||||
}
|
||||
defer httpx.DrainAndClose(rc)
|
||||
return desc, nil
|
||||
}
|
||||
|
||||
|
@ -364,16 +364,10 @@ func forwardRequest(client *http.Client, bufferPool *sync.Pool, req *http.Reques
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer forwardResp.Body.Close()
|
||||
|
||||
// Clear body and try next if non 200 response.
|
||||
//nolint:staticcheck // Keep things readable.
|
||||
if !(forwardResp.StatusCode == http.StatusOK || forwardResp.StatusCode == http.StatusPartialContent) {
|
||||
_, err = io.Copy(io.Discard, forwardResp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("expected mirror to respond with 200 OK but received: %s", forwardResp.Status)
|
||||
defer httpx.DrainAndClose(forwardResp.Body)
|
||||
err = httpx.CheckResponseStatus(forwardResp, http.StatusOK, http.StatusPartialContent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO (phillebaba): Is it possible to retry if copy fails half way through?
|
||||
|
Loading…
x
Reference in New Issue
Block a user