Add drain and close function

Signed-off-by: Philip Laine <philip.laine@gmail.com>
This commit is contained in:
Philip Laine 2025-06-05 15:13:03 +02:00
parent ab4d9a5d4d
commit b56f7baa5c
No known key found for this signature in database
GPG Key ID: F6D0B743CA3EFF33
8 changed files with 114 additions and 74 deletions

View File

@ -2,7 +2,6 @@ version: "2"
linters:
default: none
enable:
- bodyclose
- errcheck
- gocritic
- govet

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [#905](https://github.com/spegel-org/spegel/pull/905) Change mirror type to url and add byte range parameter.
- [#909](https://github.com/spegel-org/spegel/pull/909) Add base http client and transport.
- [#910](https://github.com/spegel-org/spegel/pull/910) Add drain and close function.
### Changed

View File

@ -3,7 +3,6 @@ package cleanup
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/url"
@ -130,11 +129,7 @@ func probeIPs(ctx context.Context, client *http.Client, ips []net.IPAddr, port s
if err != nil {
return err
}
defer resp.Body.Close()
_, err = io.Copy(io.Discard, resp.Body)
if err != nil {
return err
}
defer httpx.DrainAndClose(resp.Body)
err = httpx.CheckResponseStatus(resp, http.StatusOK)
if err != nil {
return err

View File

@ -1,6 +1,8 @@
package httpx
import (
"errors"
"io"
"net"
"net/http"
"time"
@ -41,3 +43,25 @@ func BaseTransport() *http.Transport {
ExpectContinueTimeout: 1 * time.Second,
}
}
const (
// MaxReadBytes is the maximum amount of bytes read when draining a response or reading error message.
MaxReadBytes = 512 * 1024
)
// DrainAndCloses empties the body buffer before closing the body.
func DrainAndClose(rc io.ReadCloser) error {
errs := []error{}
n, err := io.Copy(io.Discard, io.LimitReader(rc, MaxReadBytes+1))
if err != nil {
errs = append(errs, err)
}
if n > MaxReadBytes {
errs = append(errs, errors.New("reader has more data than max read bytes"))
}
err = rc.Close()
if err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
}

View File

@ -1,6 +1,8 @@
package httpx
import (
"bytes"
"io"
"net/http"
"testing"
"time"
@ -22,3 +24,22 @@ func TestBaseTransport(t *testing.T) {
BaseTransport()
}
func TestDrainAndClose(t *testing.T) {
t.Parallel()
buf := bytes.NewBuffer(nil)
err := DrainAndClose(io.NopCloser(buf))
require.NoError(t, err)
require.Empty(t, buf.Bytes())
buf = bytes.NewBuffer(make([]byte, MaxReadBytes))
err = DrainAndClose(io.NopCloser(buf))
require.NoError(t, err)
require.Empty(t, buf.Bytes())
buf = bytes.NewBuffer(make([]byte, MaxReadBytes+10))
err = DrainAndClose(io.NopCloser(buf))
require.EqualError(t, err, "reader has more data than max read bytes")
require.Len(t, buf.Bytes(), 9)
}

View File

@ -55,9 +55,11 @@ func getErrorMessage(resp *http.Response) (string, error) {
"application/xml",
}
if !slices.Contains(contentTypes, resp.Header.Get(HeaderContentType)) {
_, err := io.Copy(io.Discard, resp.Body)
return "", nil
}
b, err := io.ReadAll(io.LimitReader(resp.Body, MaxReadBytes))
if err != nil {
return "", err
}
b, err := io.ReadAll(resp.Body)
return string(b), err
}

View File

@ -63,66 +63,74 @@ func (c *Client) Pull(ctx context.Context, img Image, mirror *url.URL) ([]PullMe
queue = queue[1:]
start := time.Now()
rc, desc, err := c.Get(ctx, dist, mirror, nil)
if err != nil {
return nil, err
}
desc, err := func() (ocispec.Descriptor, error) {
rc, desc, err := c.Get(ctx, dist, mirror, nil)
if err != nil {
return ocispec.Descriptor{}, err
}
defer httpx.DrainAndClose(rc)
switch dist.Kind {
case DistributionKindBlob:
_, copyErr := io.Copy(io.Discard, rc)
closeErr := rc.Close()
err := errors.Join(copyErr, closeErr)
if err != nil {
return nil, err
}
case DistributionKindManifest:
b, readErr := io.ReadAll(rc)
closeErr := rc.Close()
err = errors.Join(readErr, closeErr)
if err != nil {
return nil, err
}
switch desc.MediaType {
case images.MediaTypeDockerSchema2ManifestList, ocispec.MediaTypeImageIndex:
var idx ocispec.Index
if err := json.Unmarshal(b, &idx); err != nil {
return nil, err
}
for _, m := range idx.Manifests {
// TODO: Add platform option.
//nolint: staticcheck // Simplify in the future.
if !(m.Platform.OS == runtime.GOOS && m.Platform.Architecture == runtime.GOARCH) {
continue
}
queue = append(queue, DistributionPath{
Kind: DistributionKindManifest,
Name: dist.Name,
Digest: m.Digest,
Registry: dist.Registry,
})
}
case images.MediaTypeDockerSchema2Manifest, ocispec.MediaTypeImageManifest:
var manifest ocispec.Manifest
err := json.Unmarshal(b, &manifest)
switch dist.Kind {
case DistributionKindBlob:
// Right now we are just discarding the contents because we do not have a writable store.
_, copyErr := io.Copy(io.Discard, rc)
closeErr := rc.Close()
err := errors.Join(copyErr, closeErr)
if err != nil {
return nil, err
return ocispec.Descriptor{}, err
}
queue = append(queue, DistributionPath{
Kind: DistributionKindBlob,
Name: dist.Name,
Digest: manifest.Config.Digest,
Registry: dist.Registry,
})
for _, layer := range manifest.Layers {
case DistributionKindManifest:
b, readErr := io.ReadAll(rc)
closeErr := rc.Close()
err = errors.Join(readErr, closeErr)
if err != nil {
return ocispec.Descriptor{}, err
}
switch desc.MediaType {
case images.MediaTypeDockerSchema2ManifestList, ocispec.MediaTypeImageIndex:
var idx ocispec.Index
if err := json.Unmarshal(b, &idx); err != nil {
return ocispec.Descriptor{}, err
}
for _, m := range idx.Manifests {
// TODO: Add platform option.
//nolint: staticcheck // Simplify in the future.
if !(m.Platform.OS == runtime.GOOS && m.Platform.Architecture == runtime.GOARCH) {
continue
}
queue = append(queue, DistributionPath{
Kind: DistributionKindManifest,
Name: dist.Name,
Digest: m.Digest,
Registry: dist.Registry,
})
}
case images.MediaTypeDockerSchema2Manifest, ocispec.MediaTypeImageManifest:
var manifest ocispec.Manifest
err := json.Unmarshal(b, &manifest)
if err != nil {
return ocispec.Descriptor{}, err
}
queue = append(queue, DistributionPath{
Kind: DistributionKindBlob,
Name: dist.Name,
Digest: layer.Digest,
Digest: manifest.Config.Digest,
Registry: dist.Registry,
})
for _, layer := range manifest.Layers {
queue = append(queue, DistributionPath{
Kind: DistributionKindBlob,
Name: dist.Name,
Digest: layer.Digest,
Registry: dist.Registry,
})
}
}
}
return desc, nil
}()
if err != nil {
return nil, err
}
metric := PullMetric{
@ -142,11 +150,7 @@ func (c *Client) Head(ctx context.Context, dist DistributionPath, mirror *url.UR
if err != nil {
return ocispec.Descriptor{}, err
}
defer rc.Close()
_, err = io.Copy(io.Discard, rc)
if err != nil {
return ocispec.Descriptor{}, err
}
defer httpx.DrainAndClose(rc)
return desc, nil
}

View File

@ -364,16 +364,10 @@ func forwardRequest(client *http.Client, bufferPool *sync.Pool, req *http.Reques
if err != nil {
return err
}
defer forwardResp.Body.Close()
// Clear body and try next if non 200 response.
//nolint:staticcheck // Keep things readable.
if !(forwardResp.StatusCode == http.StatusOK || forwardResp.StatusCode == http.StatusPartialContent) {
_, err = io.Copy(io.Discard, forwardResp.Body)
if err != nil {
return err
}
return fmt.Errorf("expected mirror to respond with 200 OK but received: %s", forwardResp.Status)
defer httpx.DrainAndClose(forwardResp.Body)
err = httpx.CheckResponseStatus(forwardResp, http.StatusOK, http.StatusPartialContent)
if err != nil {
return err
}
// TODO (phillebaba): Is it possible to retry if copy fails half way through?