Enforce use of request contexts and fix response closing

Signed-off-by: Philip Laine <philip.laine@gmail.com>
This commit is contained in:
Philip Laine 2025-06-05 16:10:24 +02:00
parent de24996538
commit 153e54ecba
No known key found for this signature in database
GPG Key ID: F6D0B743CA3EFF33
9 changed files with 49 additions and 22 deletions

View File

@ -15,6 +15,7 @@ linters:
- staticcheck
- testifylint
- unused
- noctx
settings:
errcheck:
disable-default-exclusions: true

View File

@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- [#911](https://github.com/spegel-org/spegel/pull/911) Enforce use of request contexts and fix response closing.
### Security
## v0.3.0

View File

@ -23,12 +23,13 @@ import (
var templatesFS embed.FS
type Web struct {
router routing.Router
client *oci.Client
tmpls *template.Template
router routing.Router
ociClient *oci.Client
httpClient *http.Client
tmpls *template.Template
}
func NewWeb(router routing.Router) (*Web, error) {
func NewWeb(router routing.Router, ociClient *oci.Client) (*Web, error) {
funcs := template.FuncMap{
"formatBytes": formatBytes,
"formatDuration": formatDuration,
@ -38,9 +39,10 @@ func NewWeb(router routing.Router) (*Web, error) {
return nil, err
}
return &Web{
router: router,
client: oci.NewClient(),
tmpls: tmpls,
router: router,
ociClient: ociClient,
httpClient: httpx.BaseClient(),
tmpls: tmpls,
}, nil
}
@ -63,12 +65,18 @@ func (w *Web) indexHandler(rw httpx.ResponseWriter, req *http.Request) {
func (w *Web) statsHandler(rw httpx.ResponseWriter, req *http.Request) {
//nolint: errcheck // Ignore error.
srvAddr := req.Context().Value(http.LocalAddrContextKey).(net.Addr)
resp, err := http.Get(fmt.Sprintf("http://%s/metrics", srvAddr.String()))
req, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fmt.Sprintf("http://%s/metrics", srvAddr.String()), nil)
if err != nil {
rw.WriteError(http.StatusInternalServerError, err)
return
}
defer resp.Body.Close()
resp, err := w.httpClient.Do(req)
if err != nil {
rw.WriteError(http.StatusInternalServerError, err)
return
}
defer httpx.DrainAndClose(resp.Body)
parser := expfmt.TextParser{}
metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil {
@ -151,7 +159,7 @@ func (w *Web) measureHandler(rw httpx.ResponseWriter, req *http.Request) {
if len(res.PeerResults) > 0 {
// Pull the image and measure performance.
pullMetrics, err := w.client.Pull(req.Context(), img, mirror)
pullMetrics, err := w.ociClient.Pull(req.Context(), img, mirror)
if err != nil {
rw.WriteError(http.StatusInternalServerError, err)
return

View File

@ -10,7 +10,7 @@ import (
func TestWeb(t *testing.T) {
t.Parallel()
w, err := NewWeb(nil)
w, err := NewWeb(nil, nil)
require.NoError(t, err)
require.NotNil(t, w.tmpls)
}

View File

@ -141,6 +141,8 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
return err
}
ociClient := oci.NewClient()
// OCI Store
ociStore, err := oci.NewContainerd(args.ContainerdSock, args.ContainerdNamespace, args.ContainerdRegistryConfigPath, args.MirroredRegistries, oci.WithContentPath(args.ContainerdContentPath))
if err != nil {
@ -209,7 +211,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
return regSrv.Shutdown(shutdownCtx)
})
// Metrics
// Metrics, pprof, and debug web
metrics.Register()
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(metrics.DefaultGatherer, promhttp.HandlerOpts{}))
@ -224,7 +226,7 @@ func registryCommand(ctx context.Context, args *RegistryCmd) (err error) {
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
if args.DebugWebEnabled {
web, err := web.NewWeb(router)
web, err := web.NewWeb(router, ociClient)
if err != nil {
return err
}

View File

@ -44,7 +44,6 @@ func CheckResponseStatus(resp *http.Response, expectedCodes ...int) error {
}
func getErrorMessage(resp *http.Response) (string, error) {
defer resp.Body.Close()
if resp.Request.Method == http.MethodHead {
return "", nil
}

View File

@ -209,6 +209,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath
}
err = httpx.CheckResponseStatus(resp, http.StatusOK, http.StatusPartialContent)
if err != nil {
httpx.DrainAndClose(resp.Body)
return nil, ocispec.Descriptor{}, err
}
@ -217,6 +218,7 @@ func (c *Client) fetch(ctx context.Context, method string, dist DistributionPath
}
desc, err := DescriptorFromHeader(resp.Header)
if err != nil {
httpx.DrainAndClose(resp.Body)
return nil, ocispec.Descriptor{}, err
}
return resp.Body, desc, nil
@ -257,11 +259,11 @@ func getBearerToken(ctx context.Context, wwwAuth string, client *http.Client) (s
if err != nil {
return "", err
}
defer httpx.DrainAndClose(resp.Body)
err = httpx.CheckResponseStatus(resp, http.StatusOK)
if err != nil {
return "", err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return "", err

View File

@ -12,6 +12,7 @@ import (
"github.com/go-logr/logr"
"github.com/stretchr/testify/require"
"github.com/spegel-org/spegel/pkg/httpx"
"github.com/spegel-org/spegel/pkg/oci"
"github.com/spegel-org/spegel/pkg/routing"
)
@ -236,7 +237,7 @@ func TestMirrorHandler(t *testing.T) {
srv.Handler.ServeHTTP(rw, req)
resp := rw.Result()
defer resp.Body.Close()
defer httpx.DrainAndClose(resp.Body)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, tt.expectedStatus, resp.StatusCode)

View File

@ -16,6 +16,8 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/spegel-org/spegel/pkg/httpx"
)
// Bootstrapper resolves peers to bootstrap with for the P2P router.
@ -120,14 +122,16 @@ func (b *DNSBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) {
var _ Bootstrapper = &HTTPBootstrapper{}
type HTTPBootstrapper struct {
addr string
peer string
httpClient *http.Client
addr string
peer string
}
func NewHTTPBootstrapper(addr, peer string) *HTTPBootstrapper {
return &HTTPBootstrapper{
addr: addr,
peer: peer,
httpClient: httpx.BaseClient(),
addr: addr,
peer: peer,
}
}
@ -159,11 +163,19 @@ func (bs *HTTPBootstrapper) Run(ctx context.Context, id string) error {
}
func (bs *HTTPBootstrapper) Get(ctx context.Context) ([]peer.AddrInfo, error) {
resp, err := http.DefaultClient.Get(bs.peer)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, bs.peer, nil)
if err != nil {
return nil, err
}
resp, err := bs.httpClient.Do(req)
if err != nil {
return nil, err
}
defer httpx.DrainAndClose(resp.Body)
err = httpx.CheckResponseStatus(resp, http.StatusOK)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err