diff --git a/.golangci.yaml b/.golangci.yaml index cc0da26..de608d6 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -15,6 +15,7 @@ linters: - staticcheck - testifylint - unused + - noctx settings: errcheck: disable-default-exclusions: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 25cf6ab..8f49cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- [#911](https://github.com/spegel-org/spegel/pull/911) Enforce use of request contexts and fix response closing. + ### Security ## v0.3.0 diff --git a/internal/web/web.go b/internal/web/web.go index 744751a..b3b3a6f 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -23,12 +23,13 @@ import ( var templatesFS embed.FS type Web struct { - router routing.Router - client *oci.Client - tmpls *template.Template + router routing.Router + ociClient *oci.Client + httpClient *http.Client + tmpls *template.Template } -func NewWeb(router routing.Router) (*Web, error) { +func NewWeb(router routing.Router, ociClient *oci.Client) (*Web, error) { funcs := template.FuncMap{ "formatBytes": formatBytes, "formatDuration": formatDuration, @@ -38,9 +39,10 @@ func NewWeb(router routing.Router) (*Web, error) { return nil, err } return &Web{ - router: router, - client: oci.NewClient(), - tmpls: tmpls, + router: router, + ociClient: ociClient, + httpClient: httpx.BaseClient(), + tmpls: tmpls, }, nil } @@ -63,12 +65,18 @@ func (w *Web) indexHandler(rw httpx.ResponseWriter, req *http.Request) { func (w *Web) statsHandler(rw httpx.ResponseWriter, req *http.Request) { //nolint: errcheck // Ignore error. srvAddr := req.Context().Value(http.LocalAddrContextKey).(net.Addr) - resp, err := http.Get(fmt.Sprintf("http://%s/metrics", srvAddr.String())) + req, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fmt.Sprintf("http://%s/metrics", srvAddr.String()), nil) if err != nil { rw.WriteError(http.StatusInternalServerError, err) return } - defer resp.Body.Close() + resp, err := w.httpClient.Do(req) + if err != nil { + rw.WriteError(http.StatusInternalServerError, err) + return + } + defer httpx.DrainAndClose(resp.Body) + parser := expfmt.TextParser{} metricFamilies, err := parser.TextToMetricFamilies(resp.Body) if err != nil { @@ -151,7 +159,7 @@ func (w *Web) measureHandler(rw httpx.ResponseWriter, req *http.Request) { if len(res.PeerResults) > 0 { // Pull the image and measure performance. - pullMetrics, err := w.client.Pull(req.Context(), img, mirror) + pullMetrics, err := w.ociClient.Pull(req.Context(), img, mirror) if err != nil { rw.WriteError(http.StatusInternalServerError, err) return diff --git a/internal/web/web_test.go b/internal/web/web_test.go index 9cdff86..a0eaf95 100644 --- a/internal/web/web_test.go +++ b/internal/web/web_test.go @@ -10,7 +10,7 @@ import ( func TestWeb(t *testing.T) { t.Parallel() - w, err := NewWeb(nil) + w, err := NewWeb(nil, nil) require.NoError(t, err) require.NotNil(t, w.tmpls) } diff --git a/main.go b/main.go index 3a2d017..7508d2c 100644 --- a/main.go +++ b/main.go @@ -141,6 +141,8 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) { return err } + ociClient := oci.NewClient() + // OCI Store ociStore, err := oci.NewContainerd(args.ContainerdSock, args.ContainerdNamespace, args.ContainerdRegistryConfigPath, args.MirroredRegistries, oci.WithContentPath(args.ContainerdContentPath)) if err != nil { @@ -209,7 +211,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) { return regSrv.Shutdown(shutdownCtx) }) - // Metrics + // Metrics, pprof, and debug web metrics.Register() mux := http.NewServeMux() mux.Handle("/metrics", promhttp.HandlerFor(metrics.DefaultGatherer, promhttp.HandlerOpts{})) @@ -224,7 +226,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) { mux.Handle("/debug/pprof/block", pprof.Handler("block")) mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) if args.DebugWebEnabled { - web, err := web.NewWeb(router) + web, err := web.NewWeb(router, ociClient) if err != nil { return err } diff --git a/pkg/httpx/status.go b/pkg/httpx/status.go index 4ebef42..bdec023 100644 --- a/pkg/httpx/status.go +++ b/pkg/httpx/status.go @@ -44,7 +44,6 @@ func CheckResponseStatus(resp *http.Response, expectedCodes ...int) error { } func getErrorMessage(resp *http.Response) (string, error) { - defer resp.Body.Close() if resp.Request.Method == http.MethodHead { return "", nil } diff --git a/pkg/oci/client.go b/pkg/oci/client.go index a1a5cfa..adba660 100644 --- a/pkg/oci/client.go +++ b/pkg/oci/client.go @@ -209,6 +209,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath } err = httpx.CheckResponseStatus(resp, http.StatusOK, http.StatusPartialContent) if err != nil { + httpx.DrainAndClose(resp.Body) return nil, ocispec.Descriptor{}, err } @@ -217,6 +218,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath } desc, err := DescriptorFromHeader(resp.Header) if err != nil { + httpx.DrainAndClose(resp.Body) return nil, ocispec.Descriptor{}, err } return resp.Body, desc, nil @@ -257,11 +259,11 @@ func getBearerToken(ctx context.Context, wwwAuth string, client *http.Client) (s if err != nil { return "", err } + defer httpx.DrainAndClose(resp.Body) err = httpx.CheckResponseStatus(resp, http.StatusOK) if err != nil { return "", err } - defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return "", err diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go index 0e2e771..1ce7f04 100644 --- a/pkg/registry/registry_test.go +++ b/pkg/registry/registry_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-logr/logr" "github.com/stretchr/testify/require" + "github.com/spegel-org/spegel/pkg/httpx" "github.com/spegel-org/spegel/pkg/oci" "github.com/spegel-org/spegel/pkg/routing" ) @@ -236,7 +237,7 @@ func TestMirrorHandler(t *testing.T) { srv.Handler.ServeHTTP(rw, req) resp := rw.Result() - defer resp.Body.Close() + defer httpx.DrainAndClose(resp.Body) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, tt.expectedStatus, resp.StatusCode) diff --git a/pkg/routing/bootstrap.go b/pkg/routing/bootstrap.go index ab1e952..c08a309 100644 --- a/pkg/routing/bootstrap.go +++ b/pkg/routing/bootstrap.go @@ -16,6 +16,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" + + "github.com/spegel-org/spegel/pkg/httpx" ) // Bootstrapper resolves peers to bootstrap with for the P2P router. @@ -120,14 +122,16 @@ func (b *DNSBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) { var _ Bootstrapper = &HTTPBootstrapper{} type HTTPBootstrapper struct { - addr string - peer string + httpClient *http.Client + addr string + peer string } func NewHTTPBootstrapper(addr, peer string) *HTTPBootstrapper { return &HTTPBootstrapper{ - addr: addr, - peer: peer, + httpClient: httpx.BaseClient(), + addr: addr, + peer: peer, } } @@ -159,11 +163,19 @@ func (bs *HTTPBootstrapper) Run(ctx context.Context, id string) error { } func (bs *HTTPBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) { - resp, err := http.DefaultClient.Get(bs.peer) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, bs.peer, nil) + if err != nil { + return nil, err + } + resp, err := bs.httpClient.Do(req) + if err != nil { + return nil, err + } + defer httpx.DrainAndClose(resp.Body) + err = httpx.CheckResponseStatus(resp, http.StatusOK) if err != nil { return nil, err } - defer resp.Body.Close() b, err := io.ReadAll(resp.Body) if err != nil { return nil, err