Enforce use of request contexts and fix response closing

Signed-off-by: Philip Laine <philip.laine@gmail.com>
This commit is contained in:
Philip Laine 2025-06-05 16:10:24 +02:00
parent de24996538
commit 153e54ecba
No known key found for this signature in database
GPG Key ID: F6D0B743CA3EFF33
9 changed files with 49 additions and 22 deletions

View File

@ -15,6 +15,7 @@ linters:
- staticcheck - staticcheck
- testifylint - testifylint
- unused - unused
- noctx
settings: settings:
errcheck: errcheck:
disable-default-exclusions: true disable-default-exclusions: true

View File

@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed ### Fixed
- [#911](https://github.com/spegel-org/spegel/pull/911) Enforce use of request contexts and fix response closing.
### Security ### Security
## v0.3.0 ## v0.3.0

View File

@ -23,12 +23,13 @@ import (
var templatesFS embed.FS var templatesFS embed.FS
type Web struct { type Web struct {
router routing.Router router routing.Router
client *oci.Client ociClient *oci.Client
tmpls *template.Template httpClient *http.Client
tmpls *template.Template
} }
func NewWeb(router routing.Router) (*Web, error) { func NewWeb(router routing.Router, ociClient *oci.Client) (*Web, error) {
funcs := template.FuncMap{ funcs := template.FuncMap{
"formatBytes": formatBytes, "formatBytes": formatBytes,
"formatDuration": formatDuration, "formatDuration": formatDuration,
@ -38,9 +39,10 @@ func NewWeb(router routing.Router) (*Web, error) {
return nil, err return nil, err
} }
return &Web{ return &Web{
router: router, router: router,
client: oci.NewClient(), ociClient: ociClient,
tmpls: tmpls, httpClient: httpx.BaseClient(),
tmpls: tmpls,
}, nil }, nil
} }
@ -63,12 +65,18 @@ func (w *Web) indexHandler(rw httpx.ResponseWriter, req *http.Request) {
func (w *Web) statsHandler(rw httpx.ResponseWriter, req *http.Request) { func (w *Web) statsHandler(rw httpx.ResponseWriter, req *http.Request) {
//nolint: errcheck // Ignore error. //nolint: errcheck // Ignore error.
srvAddr := req.Context().Value(http.LocalAddrContextKey).(net.Addr) srvAddr := req.Context().Value(http.LocalAddrContextKey).(net.Addr)
resp, err := http.Get(fmt.Sprintf("http://%s/metrics", srvAddr.String())) req, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fmt.Sprintf("http://%s/metrics", srvAddr.String()), nil)
if err != nil { if err != nil {
rw.WriteError(http.StatusInternalServerError, err) rw.WriteError(http.StatusInternalServerError, err)
return return
} }
defer resp.Body.Close() resp, err := w.httpClient.Do(req)
if err != nil {
rw.WriteError(http.StatusInternalServerError, err)
return
}
defer httpx.DrainAndClose(resp.Body)
parser := expfmt.TextParser{} parser := expfmt.TextParser{}
metricFamilies, err := parser.TextToMetricFamilies(resp.Body) metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil { if err != nil {
@ -151,7 +159,7 @@ func (w *Web) measureHandler(rw httpx.ResponseWriter, req *http.Request) {
if len(res.PeerResults) > 0 { if len(res.PeerResults) > 0 {
// Pull the image and measure performance. // Pull the image and measure performance.
pullMetrics, err := w.client.Pull(req.Context(), img, mirror) pullMetrics, err := w.ociClient.Pull(req.Context(), img, mirror)
if err != nil { if err != nil {
rw.WriteError(http.StatusInternalServerError, err) rw.WriteError(http.StatusInternalServerError, err)
return return

View File

@ -10,7 +10,7 @@ import (
func TestWeb(t *testing.T) { func TestWeb(t *testing.T) {
t.Parallel() t.Parallel()
w, err := NewWeb(nil) w, err := NewWeb(nil, nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, w.tmpls) require.NotNil(t, w.tmpls)
} }

View File

@ -141,6 +141,8 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
return err return err
} }
ociClient := oci.NewClient()
// OCI Store // OCI Store
ociStore, err := oci.NewContainerd(args.ContainerdSock, args.ContainerdNamespace, args.ContainerdRegistryConfigPath, args.MirroredRegistries, oci.WithContentPath(args.ContainerdContentPath)) ociStore, err := oci.NewContainerd(args.ContainerdSock, args.ContainerdNamespace, args.ContainerdRegistryConfigPath, args.MirroredRegistries, oci.WithContentPath(args.ContainerdContentPath))
if err != nil { if err != nil {
@ -209,7 +211,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
return regSrv.Shutdown(shutdownCtx) return regSrv.Shutdown(shutdownCtx)
}) })
// Metrics // Metrics, pprof, and debug web
metrics.Register() metrics.Register()
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(metrics.DefaultGatherer, promhttp.HandlerOpts{})) mux.Handle("/metrics", promhttp.HandlerFor(metrics.DefaultGatherer, promhttp.HandlerOpts{}))
@ -224,7 +226,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
mux.Handle("/debug/pprof/block", pprof.Handler("block")) mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
if args.DebugWebEnabled { if args.DebugWebEnabled {
web, err := web.NewWeb(router) web, err := web.NewWeb(router, ociClient)
if err != nil { if err != nil {
return err return err
} }

View File

@ -44,7 +44,6 @@ func CheckResponseStatus(resp *http.Response, expectedCodes ...int) error {
} }
func getErrorMessage(resp *http.Response) (string, error) { func getErrorMessage(resp *http.Response) (string, error) {
defer resp.Body.Close()
if resp.Request.Method == http.MethodHead { if resp.Request.Method == http.MethodHead {
return "", nil return "", nil
} }

View File

@ -209,6 +209,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath
} }
err = httpx.CheckResponseStatus(resp, http.StatusOK, http.StatusPartialContent) err = httpx.CheckResponseStatus(resp, http.StatusOK, http.StatusPartialContent)
if err != nil { if err != nil {
httpx.DrainAndClose(resp.Body)
return nil, ocispec.Descriptor{}, err return nil, ocispec.Descriptor{}, err
} }
@ -217,6 +218,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath
} }
desc, err := DescriptorFromHeader(resp.Header) desc, err := DescriptorFromHeader(resp.Header)
if err != nil { if err != nil {
httpx.DrainAndClose(resp.Body)
return nil, ocispec.Descriptor{}, err return nil, ocispec.Descriptor{}, err
} }
return resp.Body, desc, nil return resp.Body, desc, nil
@ -257,11 +259,11 @@ func getBearerToken(ctx context.Context, wwwAuth string, client *http.Client) (s
if err != nil { if err != nil {
return "", err return "", err
} }
defer httpx.DrainAndClose(resp.Body)
err = httpx.CheckResponseStatus(resp, http.StatusOK) err = httpx.CheckResponseStatus(resp, http.StatusOK)
if err != nil { if err != nil {
return "", err return "", err
} }
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -12,6 +12,7 @@ import (
"github.com/go-logr/logr" "github.com/go-logr/logr"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/spegel-org/spegel/pkg/httpx"
"github.com/spegel-org/spegel/pkg/oci" "github.com/spegel-org/spegel/pkg/oci"
"github.com/spegel-org/spegel/pkg/routing" "github.com/spegel-org/spegel/pkg/routing"
) )
@ -236,7 +237,7 @@ func TestMirrorHandler(t *testing.T) {
srv.Handler.ServeHTTP(rw, req) srv.Handler.ServeHTTP(rw, req)
resp := rw.Result() resp := rw.Result()
defer resp.Body.Close() defer httpx.DrainAndClose(resp.Body)
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tt.expectedStatus, resp.StatusCode) require.Equal(t, tt.expectedStatus, resp.StatusCode)

View File

@ -16,6 +16,8 @@ import (
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
"github.com/spegel-org/spegel/pkg/httpx"
) )
// Bootstrapper resolves peers to bootstrap with for the P2P router. // Bootstrapper resolves peers to bootstrap with for the P2P router.
@ -120,14 +122,16 @@ func (b *DNSBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) {
var _ Bootstrapper = &HTTPBootstrapper{} var _ Bootstrapper = &HTTPBootstrapper{}
type HTTPBootstrapper struct { type HTTPBootstrapper struct {
addr string httpClient *http.Client
peer string addr string
peer string
} }
func NewHTTPBootstrapper(addr, peer string) *HTTPBootstrapper { func NewHTTPBootstrapper(addr, peer string) *HTTPBootstrapper {
return &HTTPBootstrapper{ return &HTTPBootstrapper{
addr: addr, httpClient: httpx.BaseClient(),
peer: peer, addr: addr,
peer: peer,
} }
} }
@ -159,11 +163,19 @@ func (bs *HTTPBootstrapper) Run(ctx context.Context, id string) error {
} }
func (bs *HTTPBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) { func (bs *HTTPBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) {
resp, err := http.DefaultClient.Get(bs.peer) req, err := http.NewRequestWithContext(ctx, http.MethodGet, bs.peer, nil)
if err != nil {
return nil, err
}
resp, err := bs.httpClient.Do(req)
if err != nil {
return nil, err
}
defer httpx.DrainAndClose(resp.Body)
err = httpx.CheckResponseStatus(resp, http.StatusOK)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err