Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit ae9cb5f

Browse files
committed
refactor(hamt): remove child interface from hamt pkg
1 parent e3cca8a commit ae9cb5f

File tree

1 file changed

+61
-50
lines changed

1 file changed

+61
-50
lines changed

hamt/hamt.go

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,28 @@ const (
3939
HashMurmur3 uint64 = 0x22
4040
)
4141

42+
type nodeType int
43+
44+
const (
45+
invalidNode nodeType = iota
46+
shardNode
47+
shardValueNode
48+
)
49+
50+
func (ds *Shard) nodeType() nodeType {
51+
if ds.key != "" && ds.val != nil {
52+
return shardValueNode
53+
}
54+
return shardNode
55+
}
56+
4257
// A Shard represents the HAMT. It should be initialized with NewShard().
4358
type Shard struct {
4459
nd *dag.ProtoNode
4560

4661
bitfield bitfield.Bitfield
4762

48-
children []child
63+
children []*Shard
4964

5065
tableSize int
5166
tableSizeLg2 int
@@ -57,12 +72,10 @@ type Shard struct {
5772
maxpadlen int
5873

5974
dserv ipld.DAGService
60-
}
6175

62-
// child can either be another shard, or a leaf node value
63-
type child interface {
64-
Link() (*ipld.Link, error)
65-
Label() string
76+
// leaf node
77+
key string
78+
val *ipld.Link
6679
}
6780

6881
// NewShard creates a new, empty HAMT shard with the given size.
@@ -119,7 +132,7 @@ func NewHamtFromDag(dserv ipld.DAGService, nd ipld.Node) (*Shard, error) {
119132
}
120133

121134
ds.nd = pbnd.Copy().(*dag.ProtoNode)
122-
ds.children = make([]child, len(pbnd.Links()))
135+
ds.children = make([]*Shard, len(pbnd.Links()))
123136
ds.bitfield.SetBytes(fsn.Data())
124137
ds.hashFunc = fsn.HashType()
125138
ds.builder = ds.nd.CidBuilder()
@@ -188,23 +201,9 @@ func (ds *Shard) Node() (ipld.Node, error) {
188201
return out, nil
189202
}
190203

191-
type shardValue struct {
192-
key string
193-
val *ipld.Link
194-
}
195-
196-
// Link returns a link to this node
197-
func (sv *shardValue) Link() (*ipld.Link, error) {
198-
return sv.val, nil
199-
}
200-
201-
func (sv *shardValue) Label() string {
202-
return sv.key
203-
}
204-
205-
func (ds *Shard) makeShardValue(lnk *ipld.Link) *shardValue {
204+
func (ds *Shard) makeShardValue(lnk *ipld.Link) *Shard {
206205
lnk2 := *lnk
207-
return &shardValue{
206+
return &Shard{
208207
key: lnk.Name[ds.maxpadlen:],
209208
val: &lnk2,
210209
}
@@ -219,6 +218,10 @@ func hash(val []byte) []byte {
219218
// Label for Shards is the empty string, this is used to differentiate them from
220219
// value entries
221220
func (ds *Shard) Label() string {
221+
nodeType := ds.nodeType()
222+
if nodeType == shardValueNode {
223+
return ds.key
224+
}
222225
return ""
223226
}
224227

@@ -250,7 +253,7 @@ func (ds *Shard) Find(ctx context.Context, name string) (*ipld.Link, error) {
250253
hv := &hashBits{b: hash([]byte(name))}
251254

252255
var out *ipld.Link
253-
err := ds.getValue(ctx, hv, name, func(sv *shardValue) error {
256+
err := ds.getValue(ctx, hv, name, func(sv *Shard) error {
254257
out = sv.val
255258
return nil
256259
})
@@ -282,7 +285,7 @@ func (ds *Shard) childLinkType(lnk *ipld.Link) (linkType, error) {
282285
// getChild returns the i'th child of this shard. If it is cached in the
283286
// children array, it will return it from there. Otherwise, it loads the child
284287
// node from disk.
285-
func (ds *Shard) getChild(ctx context.Context, i int) (child, error) {
288+
func (ds *Shard) getChild(ctx context.Context, i int) (*Shard, error) {
286289
if i >= len(ds.children) || i < 0 {
287290
return nil, fmt.Errorf("invalid index passed to getChild (likely corrupt bitfield)")
288291
}
@@ -301,14 +304,14 @@ func (ds *Shard) getChild(ctx context.Context, i int) (child, error) {
301304

302305
// loadChild reads the i'th child node of this shard from disk and returns it
303306
// as a 'child' interface
304-
func (ds *Shard) loadChild(ctx context.Context, i int) (child, error) {
307+
func (ds *Shard) loadChild(ctx context.Context, i int) (*Shard, error) {
305308
lnk := ds.nd.Links()[i]
306309
lnkLinkType, err := ds.childLinkType(lnk)
307310
if err != nil {
308311
return nil, err
309312
}
310313

311-
var c child
314+
var c *Shard
312315
if lnkLinkType == shardLink {
313316
nd, err := lnk.GetNode(ctx, ds.dserv)
314317
if err != nil {
@@ -328,12 +331,17 @@ func (ds *Shard) loadChild(ctx context.Context, i int) (child, error) {
328331
return c, nil
329332
}
330333

331-
func (ds *Shard) setChild(i int, c child) {
334+
func (ds *Shard) setChild(i int, c *Shard) {
332335
ds.children[i] = c
333336
}
334337

335338
// Link returns a merklelink to this shard node
336339
func (ds *Shard) Link() (*ipld.Link, error) {
340+
nodeType := ds.nodeType()
341+
if nodeType == shardValueNode {
342+
return ds.val, nil
343+
}
344+
337345
nd, err := ds.Node()
338346
if err != nil {
339347
return nil, err
@@ -356,13 +364,13 @@ func (ds *Shard) insertChild(idx int, key string, lnk *ipld.Link) error {
356364
ds.bitfield.SetBit(idx)
357365

358366
lnk.Name = ds.linkNamePrefix(idx) + key
359-
sv := &shardValue{
367+
sv := &Shard{
360368
key: key,
361369
val: lnk,
362370
}
363371

364-
ds.children = append(ds.children[:i], append([]child{sv}, ds.children[i:]...)...)
365-
ds.nd.SetLinks(append(ds.nd.Links()[:i], append([]*ipld.Link{nil}, ds.nd.Links()[i:]...)...))
372+
ds.children = append(ds.children[:i], append([]*Shard{sv}, ds.children[i:]...)...)
373+
ds.nd.SetLinks(append(ds.nd.Links()[:i], append([]*ipld.Link{lnk}, ds.nd.Links()[i:]...)...))
366374
return nil
367375
}
368376

@@ -380,7 +388,7 @@ func (ds *Shard) rmChild(i int) error {
380388
return nil
381389
}
382390

383-
func (ds *Shard) getValue(ctx context.Context, hv *hashBits, key string, cb func(*shardValue) error) error {
391+
func (ds *Shard) getValue(ctx context.Context, hv *hashBits, key string, cb func(*Shard) error) error {
384392
idx := hv.Next(ds.tableSizeLg2)
385393
if ds.bitfield.Bit(int(idx)) {
386394
cindex := ds.indexForBitPos(idx)
@@ -390,10 +398,11 @@ func (ds *Shard) getValue(ctx context.Context, hv *hashBits, key string, cb func
390398
return err
391399
}
392400

393-
switch child := child.(type) {
394-
case *Shard:
401+
childType := child.nodeType()
402+
switch childType {
403+
case shardNode:
395404
return child.getValue(ctx, hv, key, cb)
396-
case *shardValue:
405+
case shardValueNode:
397406
if child.key == key {
398407
return cb(child)
399408
}
@@ -408,7 +417,7 @@ func (ds *Shard) EnumLinks(ctx context.Context) ([]*ipld.Link, error) {
408417
var links []*ipld.Link
409418
var setlk sync.Mutex
410419

411-
getLinks := makeAsyncTrieGetLinks(ds.dserv, func(sv *shardValue) error {
420+
getLinks := makeAsyncTrieGetLinks(ds.dserv, func(sv *Shard) error {
412421
lnk := sv.val
413422
lnk.Name = sv.key
414423
setlk.Lock()
@@ -425,7 +434,7 @@ func (ds *Shard) EnumLinks(ctx context.Context) ([]*ipld.Link, error) {
425434

426435
// ForEachLink walks the Shard and calls the given function.
427436
func (ds *Shard) ForEachLink(ctx context.Context, f func(*ipld.Link) error) error {
428-
return ds.walkTrie(ctx, func(sv *shardValue) error {
437+
return ds.walkTrie(ctx, func(sv *Shard) error {
429438
lnk := sv.val
430439
lnk.Name = sv.key
431440

@@ -436,7 +445,7 @@ func (ds *Shard) ForEachLink(ctx context.Context, f func(*ipld.Link) error) erro
436445
// makeAsyncTrieGetLinks builds a getLinks function that can be used with EnumerateChildrenAsync
437446
// to iterate a HAMT shard. It takes an IPLD Dag Service to fetch nodes, and a call back that will get called
438447
// on all links to leaf nodes in a HAMT tree, so they can be collected for an EnumLinks operation
439-
func makeAsyncTrieGetLinks(dagService ipld.DAGService, onShardValue func(*shardValue) error) dag.GetLinks {
448+
func makeAsyncTrieGetLinks(dagService ipld.DAGService, onShardValue func(shard *Shard) error) dag.GetLinks {
440449

441450
return func(ctx context.Context, currentCid cid.Cid) ([]*ipld.Link, error) {
442451
node, err := dagService.Get(ctx, currentCid)
@@ -471,20 +480,22 @@ func makeAsyncTrieGetLinks(dagService ipld.DAGService, onShardValue func(*shardV
471480
}
472481
}
473482

474-
func (ds *Shard) walkTrie(ctx context.Context, cb func(*shardValue) error) error {
483+
func (ds *Shard) walkTrie(ctx context.Context, cb func(*Shard) error) error {
475484
for idx := range ds.children {
476485
c, err := ds.getChild(ctx, idx)
477486
if err != nil {
478487
return err
479488
}
480489

481-
switch c := c.(type) {
482-
case *shardValue:
490+
childType := c.nodeType()
491+
492+
switch childType {
493+
case shardValueNode:
483494
if err := cb(c); err != nil {
484495
return err
485496
}
486497

487-
case *Shard:
498+
case shardNode:
488499
if err := c.walkTrie(ctx, cb); err != nil {
489500
return err
490501
}
@@ -509,8 +520,10 @@ func (ds *Shard) modifyValue(ctx context.Context, hv *hashBits, key string, val
509520
return err
510521
}
511522

512-
switch child := child.(type) {
513-
case *Shard:
523+
childType := child.nodeType()
524+
525+
switch childType {
526+
case shardNode:
514527
err := child.modifyValue(ctx, hv, key, val)
515528
if err != nil {
516529
return err
@@ -526,17 +539,15 @@ func (ds *Shard) modifyValue(ctx context.Context, hv *hashBits, key string, val
526539
ds.bitfield.UnsetBit(idx)
527540
return ds.rmChild(cindex)
528541
case 1:
529-
nchild, ok := child.children[0].(*shardValue)
530-
if ok {
531-
// sub-shard with a single value element, collapse it
532-
ds.setChild(cindex, nchild)
533-
}
542+
nchild := child.children[0]
543+
// sub-shard with a single value element, collapse it
544+
ds.setChild(cindex, nchild)
534545
return nil
535546
}
536547
}
537548

538549
return nil
539-
case *shardValue:
550+
case shardValueNode:
540551
if child.key == key {
541552
// value modification
542553
if val == nil {

0 commit comments

Comments
 (0)