diff --git a/endpoint/base.go b/endpoint/base.go index 0b860d9..192139c 100644 --- a/endpoint/base.go +++ b/endpoint/base.go @@ -42,14 +42,30 @@ func (b *BaseEndpoint[T]) GetByID(id uint64) (*T, error) { } func (b *BaseEndpoint[T]) GetByIDs(ids []uint64) ([]*T, error) { - builder := strings.Builder{} - for i, v := range ids { - if i > 0 { - builder.WriteByte(',') - } - builder.WriteString(strconv.FormatUint(v, 10)) + if len(ids) == 0 { + return nil, fmt.Errorf("ids cant be empty") } - return b.Query(fmt.Sprintf("where id = (%s); fields *;", builder.String())) + batches := make([][]uint64, 0) + for i := 0; i < len(ids); i += 500 { + end := min(i+500, len(ids)) + batches = append(batches, ids[i:end]) + } + res := []*T{} + for _, batch := range batches { + builder := strings.Builder{} + for i, v := range batch { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(strconv.FormatUint(v, 10)) + } + batchRes, err := b.Query(fmt.Sprintf("where id = (%s); fields *; limit 500;", builder.String())) + if err != nil { + return nil, err + } + res = append(res, batchRes...) + } + return res, nil } func (b *BaseEndpoint[T]) Count() (uint64, error) {