From aaf697a005be090d59fb6562873aff20254ae097 Mon Sep 17 00:00:00 2001 From: nite Date: Mon, 3 Nov 2025 18:25:53 +1100 Subject: [PATCH] refactor: Improve endpoint robustness and client clarity Refactored client request parameter naming and enhanced endpoint methods for better validation and idiomatic behavior. - Renamed `URL` parameter to `requestURL` in `Client.Request` for improved clarity and to avoid potential naming conflicts. - Added validation to `BaseEndpoint.GetByID` to prevent queries with an ID of 0, returning an error. - Modified `BaseEndpoint.GetByIDs` to return an empty slice (`[]*T{}`) instead of an error when no IDs are provided, aligning with common Go idioms for empty result sets. - Enhanced `BaseEndpoint.Count` method to return an error if the API reports a count of 0, ensuring that a successful count operation always yields a positive result. --- client.go | 6 +++--- endpoint/base.go | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 102340f..c8d7384 100644 --- a/client.go +++ b/client.go @@ -113,7 +113,7 @@ func NewWithFlaresolverr(clientID, clientSecret string, f *flaresolverr.Flaresol return c } -func (g *Client) Request(ctx context.Context, method string, URL string, dataBody any) (*resty.Response, error) { +func (g *Client) Request(ctx context.Context, method string, requestURL string, dataBody any) (*resty.Response, error) { err := g.limiter.Wait(ctx) if err != nil { return nil, fmt.Errorf("failed to get rate limiter token: %w", err) @@ -129,14 +129,14 @@ func (g *Client) Request(ctx context.Context, method string, URL string, dataBod "Authorization": "Bearer " + t, "User-Agent": "", "Content-Type": "text/plain", - }).Execute(strings.ToUpper(method), URL) + }).Execute(strings.ToUpper(method), requestURL) if resp.StatusCode() != 200 { return nil, fmt.Errorf("failed to request, expected 200 but got: %v", resp.StatusCode()) } if err != nil { - return nil, fmt.Errorf("failed to request: %s: %w", URL, err) + return nil, fmt.Errorf("failed to request: %s: %w", requestURL, err) } return resp, nil } diff --git a/endpoint/base.go b/endpoint/base.go index 494149f..bf4c75f 100644 --- a/endpoint/base.go +++ b/endpoint/base.go @@ -2,6 +2,7 @@ package endpoint import ( "context" + "errors" "fmt" "strconv" "strings" @@ -32,6 +33,9 @@ func (b *BaseEndpoint[T]) Query(ctx context.Context, query string) ([]*T, error) } func (b *BaseEndpoint[T]) GetByID(ctx context.Context, id uint64) (*T, error) { + if id == 0 { + return nil, errors.New("id cant be 0") + } res, err := b.Query(ctx, fmt.Sprintf("where id = %d; fields *;", id)) if err != nil { return nil, err @@ -44,7 +48,7 @@ func (b *BaseEndpoint[T]) GetByID(ctx context.Context, id uint64) (*T, error) { func (b *BaseEndpoint[T]) GetByIDs(ctx context.Context, ids []uint64) ([]*T, error) { if len(ids) == 0 { - return nil, fmt.Errorf("ids cant be empty") + return []*T{}, nil } batches := make([][]uint64, 0) for i := 0; i < len(ids); i += 500 { @@ -80,7 +84,11 @@ func (b *BaseEndpoint[T]) Count(ctx context.Context) (uint64, error) { return 0, fmt.Errorf("failed to unmarshal: %w", err) } - return uint64(res.Count), nil + if res.Count > 0 { + return uint64(res.Count), nil + } else { + return 0, fmt.Errorf("failed to count, count should larger than 0, but got %v", res.Count) + } } func (b *BaseEndpoint[T]) Paginated(ctx context.Context, offset, limit uint64) ([]*T, error) {