diff --git a/CHANGELOG.md b/CHANGELOG.md index 86b3237..d164a64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - [#906](https://github.com/spegel-org/spegel/pull/906) Replace HTTP header strings with httpx constants. +- [#916](https://github.com/spegel-org/spegel/pull/916) Refactor OCI client options and add header configuration. ### Deprecated diff --git a/internal/web/web.go b/internal/web/web.go index b3b3a6f..76b6cca 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -159,7 +159,7 @@ func (w *Web) measureHandler(rw httpx.ResponseWriter, req *http.Request) { if len(res.PeerResults) > 0 { // Pull the image and measure performance. - pullMetrics, err := w.ociClient.Pull(req.Context(), img, mirror) + pullMetrics, err := w.ociClient.Pull(req.Context(), img, oci.WithFetchMirror(mirror)) if err != nil { rw.WriteError(http.StatusInternalServerError, err) return diff --git a/pkg/httpx/header.go b/pkg/httpx/header.go new file mode 100644 index 0000000..c42321c --- /dev/null +++ b/pkg/httpx/header.go @@ -0,0 +1,30 @@ +package httpx + +import "net/http" + +const ( + HeaderContentType = "Content-Type" + HeaderContentLength = "Content-Length" + HeaderContentRange = "Content-Range" + HeaderRange = "Range" + HeaderAcceptRanges = "Accept-Ranges" + HeaderUserAgent = "User-Agent" + HeaderAccept = "Accept" + HeaderAuthorization = "Authorization" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" +) + +const ( + ContentTypeBinary = "application/octet-stream" + ContentTypeJSON = "application/json" +) + +// CopyHeader copies header from source to destination. +func CopyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} diff --git a/pkg/httpx/header_test.go b/pkg/httpx/header_test.go new file mode 100644 index 0000000..270f085 --- /dev/null +++ b/pkg/httpx/header_test.go @@ -0,0 +1,20 @@ +package httpx + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCopyHeader(t *testing.T) { + t.Parallel() + + src := http.Header{ + "foo": []string{"2", "1"}, + } + dst := http.Header{} + CopyHeader(dst, src) + + require.Equal(t, []string{"2", "1"}, dst.Values("foo")) +} diff --git a/pkg/httpx/httpx.go b/pkg/httpx/httpx.go index 198667f..bfcc293 100644 --- a/pkg/httpx/httpx.go +++ b/pkg/httpx/httpx.go @@ -8,24 +8,6 @@ import ( "time" ) -const ( - HeaderContentType = "Content-Type" - HeaderContentLength = "Content-Length" - HeaderContentRange = "Content-Range" - HeaderRange = "Range" - HeaderAcceptRanges = "Accept-Ranges" - HeaderUserAgent = "User-Agent" - HeaderAccept = "Accept" - HeaderAuthorization = "Authorization" - HeaderWWWAuthenticate = "WWW-Authenticate" - HeaderXForwardedFor = "X-Forwarded-For" -) - -const ( - ContentTypeBinary = "application/octet-stream" - ContentTypeJSON = "application/json" -) - // BaseClient returns a http client with reasonable defaults set. func BaseClient() *http.Client { return &http.Client{ diff --git a/pkg/oci/client.go b/pkg/oci/client.go index 42dfafa..154d83f 100644 --- a/pkg/oci/client.go +++ b/pkg/oci/client.go @@ -25,6 +25,39 @@ const ( HeaderDockerDigest = "Docker-Content-Digest" ) +type FetchConfig struct { + Mirror *url.URL + Header http.Header +} + +func (cfg *FetchConfig) Apply(opts ...FetchOption) error { + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt(cfg); err != nil { + return err + } + } + return nil +} + +type FetchOption func(cfg *FetchConfig) error + +func WithFetchMirror(mirror *url.URL) FetchOption { + return func(cfg *FetchConfig) error { + cfg.Mirror = mirror + return nil + } +} + +func WithFetchHeader(header http.Header) FetchOption { + return func(cfg *FetchConfig) error { + cfg.Header = header + return nil + } +} + type Client struct { hc *http.Client tc sync.Map @@ -46,7 +79,7 @@ type PullMetric struct { Duration time.Duration } -func (c *Client) Pull(ctx context.Context, img Image, mirror *url.URL) ([]PullMetric, error) { +func (c *Client) Pull(ctx context.Context, img Image, opts ...FetchOption) ([]PullMetric, error) { pullMetrics := []PullMetric{} queue := []DistributionPath{ @@ -64,7 +97,7 @@ func (c *Client) Pull(ctx context.Context, img Image, mirror *url.URL) ([]PullMe start := time.Now() desc, err := func() (ocispec.Descriptor, error) { - rc, desc, err := c.Get(ctx, dist, mirror, nil) + rc, desc, err := c.Get(ctx, dist, nil, opts...) if err != nil { return ocispec.Descriptor{}, err } @@ -145,8 +178,8 @@ func (c *Client) Pull(ctx context.Context, img Image, mirror *url.URL) ([]PullMe return pullMetrics, nil } -func (c *Client) Head(ctx context.Context, dist DistributionPath, mirror *url.URL) (ocispec.Descriptor, error) { - rc, desc, err := c.fetch(ctx, http.MethodHead, dist, mirror, nil) +func (c *Client) Head(ctx context.Context, dist DistributionPath, opts ...FetchOption) (ocispec.Descriptor, error) { + rc, desc, err := c.fetch(ctx, http.MethodHead, dist, nil, opts...) if err != nil { return ocispec.Descriptor{}, err } @@ -154,22 +187,28 @@ func (c *Client) Head(ctx context.Context, dist DistributionPath, mirror *url.UR return desc, nil } -func (c *Client) Get(ctx context.Context, dist DistributionPath, mirror *url.URL, brr []httpx.ByteRange) (io.ReadCloser, ocispec.Descriptor, error) { - rc, desc, err := c.fetch(ctx, http.MethodGet, dist, mirror, brr) +func (c *Client) Get(ctx context.Context, dist DistributionPath, brr []httpx.ByteRange, opts ...FetchOption) (io.ReadCloser, ocispec.Descriptor, error) { + rc, desc, err := c.fetch(ctx, http.MethodGet, dist, brr, opts...) if err != nil { return nil, ocispec.Descriptor{}, err } return rc, desc, nil } -func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath, mirror *url.URL, brr []httpx.ByteRange) (io.ReadCloser, ocispec.Descriptor, error) { +func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath, brr []httpx.ByteRange, opts ...FetchOption) (io.ReadCloser, ocispec.Descriptor, error) { + cfg := FetchConfig{} + err := cfg.Apply(opts...) + if err != nil { + return nil, ocispec.Descriptor{}, err + } + tcKey := dist.Registry + dist.Name u := dist.URL() - if mirror != nil { - u.Scheme = mirror.Scheme - u.Host = mirror.Host - u.Path = path.Join(mirror.Path, u.Path) + if cfg.Mirror != nil { + u.Scheme = cfg.Mirror.Scheme + u.Host = cfg.Mirror.Host + u.Path = path.Join(cfg.Mirror.Path, u.Path) } if u.Host == "docker.io" { u.Host = "registry-1.docker.io" @@ -180,6 +219,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath if err != nil { return nil, ocispec.Descriptor{}, err } + httpx.CopyHeader(req.Header, cfg.Header) req.Header.Set(httpx.HeaderUserAgent, "spegel") req.Header.Add(httpx.HeaderAccept, "application/vnd.oci.image.manifest.v1+json") req.Header.Add(httpx.HeaderAccept, "application/vnd.docker.distribution.manifest.v2+json") diff --git a/pkg/oci/client_test.go b/pkg/oci/client_test.go index 610313a..28f7a3b 100644 --- a/pkg/oci/client_test.go +++ b/pkg/oci/client_test.go @@ -65,7 +65,7 @@ func TestClient(t *testing.T) { client := NewClient() mirror, err := url.Parse(srv.URL) require.NoError(t, err) - pullResults, err := client.Pull(t.Context(), img, mirror) + pullResults, err := client.Pull(t.Context(), img, WithFetchMirror(mirror)) require.NoError(t, err) require.Len(t, pullResults, 3) @@ -74,7 +74,7 @@ func TestClient(t *testing.T) { Name: img.Repository, Digest: blobs[0].Digest, } - desc, err := client.Head(t.Context(), dist, mirror) + desc, err := client.Head(t.Context(), dist, WithFetchMirror(mirror)) require.NoError(t, err) require.Equal(t, dist.Digest, desc.Digest) require.Equal(t, httpx.ContentTypeBinary, desc.MediaType) diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index ba92885..c37926d 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -359,7 +359,7 @@ func forwardRequest(client *http.Client, bufferPool *sync.Pool, req *http.Reques if err != nil { return err } - copyHeader(forwardReq.Header, req.Header) + httpx.CopyHeader(forwardReq.Header, req.Header) forwardResp, err := client.Do(forwardReq) if err != nil { return err @@ -372,7 +372,7 @@ func forwardRequest(client *http.Client, bufferPool *sync.Pool, req *http.Reques // TODO (phillebaba): Is it possible to retry if copy fails half way through? // Copy forward response to response writer. - copyHeader(rw.Header(), forwardResp.Header) + httpx.CopyHeader(rw.Header(), forwardResp.Header) rw.WriteHeader(http.StatusOK) //nolint: errcheck // Ignore buf := bufferPool.Get().(*[]byte) @@ -383,11 +383,3 @@ func forwardRequest(client *http.Client, bufferPool *sync.Pool, req *http.Reques } return nil } - -func copyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 1ce7f04..d4ca5f9 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -259,15 +259,3 @@ func TestMirrorHandler(t *testing.T) { } } } - -func TestCopyHeader(t *testing.T) { - t.Parallel() - - src := http.Header{ - "foo": []string{"2", "1"}, - } - dst := http.Header{} - copyHeader(dst, src) - - require.Equal(t, []string{"2", "1"}, dst.Values("foo")) -}