Add GF16 Split/Join (#194)

* Add GF16 Split/Join

Also check if we have enough shards when reconstructing.
master
Klaus Post 2022-07-26 09:14:03 -07:00 committed by GitHub
parent 3a82d28edb
commit 77188e96d2
2 changed files with 267 additions and 2 deletions
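For context, here is a minimal sketch of how the new Split/Join path could be exercised end to end through the public API (New, Split, Encode, Reconstruct, Join). The shard counts and payload size are illustrative, and it assumes that an encoder with more than 256 total shards is served by the GF16 (leopard) backend this commit extends.

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/klauspost/reedsolomon"
)

func main() {
	// 500 data + 300 parity shards; totals above 256 are assumed to select
	// the GF16 (leopard) backend.
	enc, err := reedsolomon.New(500, 300)
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative payload; any size works, Split pads as needed.
	data := bytes.Repeat([]byte("example payload "), 1<<14)

	// Split cuts (and pads) the input into 800 equal-length shards,
	// then Encode fills in the parity shards.
	shards, err := enc.Split(data)
	if err != nil {
		log.Fatal(err)
	}
	if err := enc.Encode(shards); err != nil {
		log.Fatal(err)
	}

	// Drop a data shard and reconstruct it from the survivors.
	shards[0] = nil
	if err := enc.Reconstruct(shards); err != nil {
		log.Fatal(err)
	}

	// Join writes back exactly len(data) bytes, discarding the padding
	// that Split added.
	var buf bytes.Buffer
	if err := enc.Join(&buf, shards, len(data)); err != nil {
		log.Fatal(err)
	}
	fmt.Println("round trip ok:", bytes.Equal(buf.Bytes(), data))
}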


@@ -207,7 +207,43 @@ func (r *reedSolomonFF16) EncodeIdx(dataShard []byte, idx int, parity [][]byte)
}
func (r *reedSolomonFF16) Join(dst io.Writer, shards [][]byte, outSize int) error {
return errors.New("not implemented")
// Do we have enough shards?
if len(shards) < r.DataShards {
return ErrTooFewShards
}
shards = shards[:r.DataShards]
// Do we have enough data?
size := 0
for _, shard := range shards {
if shard == nil {
return ErrReconstructRequired
}
size += len(shard)
// Do we have enough data already?
if size >= outSize {
break
}
}
if size < outSize {
return ErrShortData
}
// Copy data to dst
write := outSize
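// write tracks how many output bytes remain to be written; the final
// shard is truncated so that exactly outSize bytes reach dst.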
for _, shard := range shards {
if write < len(shard) {
_, err := dst.Write(shard[:write])
return err
}
n, err := dst.Write(shard)
if err != nil {
return err
}
write -= n
}
return nil
}
func (r *reedSolomonFF16) Update(shards [][]byte, newDatashards [][]byte) error {
@@ -215,7 +251,46 @@ func (r *reedSolomonFF16) Update(shards [][]byte, newDatashards [][]byte) error
}
func (r *reedSolomonFF16) Split(data []byte) ([][]byte, error) {
return nil, errors.New("not implemented")
if len(data) == 0 {
return nil, ErrShortData
}
dataLen := len(data)
// Calculate number of bytes per data shard.
perShard := (len(data) + r.DataShards - 1) / r.DataShards
perShard = ((perShard + 63) / 64) * 64
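// For illustration: with 500 data shards and a 2 MiB (2097152-byte) input,
// ceil(2097152/500) = 4195 bytes, rounded up to the next multiple of 64, i.e. 4224 bytes per shard.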
if cap(data) > len(data) {
data = data[:cap(data)]
}
// Only allocate memory if necessary
var padding []byte
if len(data) < (r.Shards * perShard) {
// calculate maximum number of full shards in `data` slice
fullShards := len(data) / perShard
padding = make([]byte, r.Shards*perShard-perShard*fullShards)
copy(padding, data[perShard*fullShards:])
data = data[0 : perShard*fullShards]
} else {
for i := dataLen; i < dataLen+r.DataShards; i++ {
data[i] = 0
}
}
// Split into equal-length shards.
dst := make([][]byte, r.Shards)
i := 0
for ; i < len(dst) && len(data) >= perShard; i++ {
dst[i] = data[:perShard:perShard]
data = data[perShard:]
}
for j := 0; i+j < len(dst); j++ {
dst[i+j] = padding[:perShard:perShard]
padding = padding[perShard:]
}
return dst, nil
}
func (r *reedSolomonFF16) ReconstructSome(shards [][]byte, required []bool) error {
@@ -267,6 +342,29 @@ func (r *reedSolomonFF16) reconstruct(shards [][]byte, recoverAll bool) error {
return err
}
// Quick check: are all of the shards present? If so, there's
// nothing to do.
numberPresent := 0
dataPresent := 0
for i := 0; i < r.Shards; i++ {
if len(shards[i]) != 0 {
numberPresent++
if i < r.DataShards {
dataPresent++
}
}
}
if numberPresent == r.Shards || !recoverAll && dataPresent == r.DataShards {
// Cool. All of the shards have data. We don't
// need to do anything.
return nil
}
// Check if we have enough to reconstruct.
if numberPresent < r.DataShards {
return ErrTooFewShards
}
shardSize := shardSize(shards)
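// The GF16 code path requires shard lengths that are a multiple of 64 bytes
// (Split pads each shard accordingly), so any other size is rejected here.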
if shardSize%64 != 0 {
return ErrShardSize

leopard_test.go (new file, 167 lines)

@@ -0,0 +1,167 @@
package reedsolomon
import (
"bytes"
"math/rand"
"testing"
)
func TestEncoderReconstructLeo(t *testing.T) {
testEncoderReconstructLeo(t)
}
func testEncoderReconstructLeo(t *testing.T, o ...Option) {
// Create some sample data
var data = make([]byte, 2<<20)
fillRandom(data)
// Create an encoder with 500 data and 300 parity shards
enc, err := New(500, 300, testOptions(o...)...)
if err != nil {
t.Fatal(err)
}
shards, err := enc.Split(data)
if err != nil {
t.Fatal(err)
}
err = enc.Encode(shards)
if err != nil {
t.Fatal(err)
}
// Check that it verifies
ok, err := enc.Verify(shards)
if !ok || err != nil {
t.Fatal("not ok:", ok, "err:", err)
}
// Delete a shard
shards[0] = nil
// Should reconstruct
err = enc.Reconstruct(shards)
if err != nil {
t.Fatal(err)
}
// Check that it verifies
ok, err = enc.Verify(shards)
if !ok || err != nil {
t.Fatal("not ok:", ok, "err:", err)
}
// Recover original bytes
buf := new(bytes.Buffer)
err = enc.Join(buf, shards, len(data))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), data) {
t.Fatal("recovered bytes do not match")
}
// Corrupt a shard
shards[0] = nil
shards[1][0], shards[1][500] = 75, 75
// Should reconstruct (but with corrupted data)
err = enc.Reconstruct(shards)
if err != nil {
t.Fatal(err)
}
// Check that it verifies
ok, err = enc.Verify(shards)
if ok || err != nil {
t.Fatal("error or ok:", ok, "err:", err)
}
// Recovered data should not match original
buf.Reset()
err = enc.Join(buf, shards, len(data))
if err != nil {
t.Fatal(err)
}
if bytes.Equal(buf.Bytes(), data) {
t.Fatal("corrupted data matches original")
}
}
func TestEncoderReconstructFailLeo(t *testing.T) {
// Create some sample data
var data = make([]byte, 2<<20)
fillRandom(data)
// Create an encoder with 500 data and 300 parity shards
enc, err := New(500, 300, testOptions()...)
if err != nil {
t.Fatal(err)
}
shards, err := enc.Split(data)
if err != nil {
t.Fatal(err)
}
err = enc.Encode(shards)
if err != nil {
t.Fatal(err)
}
// Check that it verifies
ok, err := enc.Verify(shards)
if !ok || err != nil {
t.Fatal("not ok:", ok, "err:", err)
}
// Delete more than parity shards
for i := 0; i < 301; i++ {
shards[i] = nil
}
// Should not reconstruct
err = enc.Reconstruct(shards)
if err != ErrTooFewShards {
t.Fatal("want ErrTooFewShards, got:", err)
}
}
func TestSplitJoinLeo(t *testing.T) {
var data = make([]byte, (250<<10)-1)
rand.Seed(0)
fillRandom(data)
enc, _ := New(500, 300, testOptions()...)
shards, err := enc.Split(data)
if err != nil {
t.Fatal(err)
}
_, err = enc.Split([]byte{})
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}
buf := new(bytes.Buffer)
err = enc.Join(buf, shards, 5000)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), data[:5000]) {
t.Fatal("recovered data does match original")
}
err = enc.Join(buf, [][]byte{}, 0)
if err != ErrTooFewShards {
t.Errorf("expected %v, got %v", ErrTooFewShards, err)
}
err = enc.Join(buf, shards, len(data)+500*64)
if err != ErrShortData {
t.Errorf("expected %v, got %v", ErrShortData, err)
}
shards[0] = nil
err = enc.Join(buf, shards, len(data))
if err != ErrReconstructRequired {
t.Errorf("expected %v, got %v", ErrReconstructRequired, err)
}
}