Skip to content

Commit

Permalink
WIDE
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 9, 2024
1 parent a1f8dec commit b7fec90
Show file tree
Hide file tree
Showing 26 changed files with 1,776 additions and 151 deletions.
119 changes: 117 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@ The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh)

**Contents**
- [Usage](#usage)
- [Basic usage](#basic-usage)
- [LSH](#lsh)
- [Packing wide data](#packing-wide-data)
- [Options](#options)
- [Benchmarks](#benchmarks)
- [License](#license)

## Usage

### Basic usage

```go
package main

import (
"fmt"

"github.com/keilerkonzept/bitknn"
)

Expand All @@ -45,12 +51,121 @@ func main() {
votes := make([]float64, 2)

k := 2
model.Predict1(k, 0b101011, votes)
model.Predict1(k, 0b101011, bitknn.VoteSlice(votes))

fmt.Println("Votes:", bitknn.VoteSlice(votes))

// you can also use a map for the votes.
// this is good if you have a very large number of different labels:
votesMap := make(map[int]float64)
model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap))
fmt.Println("Votes for 0:", votesMap[0])
}
```

### LSH

Locality-Sensitive Hashing (LSH) is a type of approximate k-NN search. It's faster at the expense of accuracy.

LSH works by hashing data points such that points that are close in Hamming space tend to land in the same bucket, and computing k-nearest neighbors only on the buckets with the k nearest hashes. In particular, for *k*=1 only one bucket needs to be examined.

```go
package main

import (
"fmt"
"github.com/keilerkonzept/bitknn/lsh"
"github.com/keilerkonzept/bitknn"
)

func main() {
// feature vectors packed into uint64s
data := []uint64{0b101010, 0b111000, 0b000111}
// class labels
labels := []int{0, 1, 1}

// Define a hash function (e.g., MinHash)
hash := lsh.RandomMinHash()

// Fit an LSH model
model := lsh.Fit(data, labels, hash, bitknn.WithLinearDistanceWeighting())

// one vote counter per class
votes := make([]float64, 2)

k := 2
model.Predict1(k, 0b101011, bitknn.VoteSlice(votes))

fmt.Println("Votes:", bitknn.VoteSlice(votes))

// you can also use a map for the votes
votesMap := make(map[int]float64)
model.Predict1(k, 0b101011, bitknn.VoteMap(votesMap))
fmt.Println("Votes for 0:", votesMap[0])
}
```

The model accepts anything that implements the [`lsh.Hash` interface](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Hash) as a hash function. Several functions are pre-defined:

- [MinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): An implementation of the [MinHash scheme](https://en.m.wikipedia.org/wiki/MinHash) for bit vectors.

Constructors: [RandomMinHash](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHash), [RandomMinHashR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashR).
- [MinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHash): Concatenation of several *MinHash*es.

Constructors: [RandomMinHashes](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashes), [RandomMinHashesR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomMinHashesR).
- [Blur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Blur): A threshold-based variation on bit sampling.

Constructors: [RandomBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlur), [RandomBlurR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBlurR), [BoxBlur](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BoxBlur), .
- [BitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSample): A random sampling of bits from the feature vector.

Constructors: [RandomBitSample](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSample), [RandomBitSampleR](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#RandomBitSampleR).

For datasets of vectors longer than 64 bits, the `lsh` package also provides a [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) function, and "wide" versions of the hash functions ([MinHashWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#MinHashWide), [BlurWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BlurWide), [BitSampleWide](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#BitSampleWide))

The [`lsh.Fit`/`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit) functions accept the same [Options](#options) as the others.

### Packing wide data

If your vectors are longer than 64 bits, you can still use `bitknn` if you [pack](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) them into `[]uint64`. The [`pack` package](https://pkg.go.dev/github.com/keilerkonzept/bitknn/pack) defines helper functions to pack `string`s and `[]byte`s into `[]uint64`s.

The exact k-NN model in `bitknn` and the approximate-NN model in `lsh` each have a `Wide` variant that accepts slice-valued data points:

```go
package main

import (
"fmt"

"github.com/keilerkonzept/bitknn"
"github.com/keilerkonzept/bitknn/pack"
)

func main() {
// feature vectors packed into uint64s
data := [][]uint64{
pack.String("foo"),
pack.String("bar"),
pack.String("baz"),
}
// class labels
labels := []int{0, 1, 1}

// model := lsh.FitWide(data, labels, lsh.RandomMinHash(), bitknn.WithLinearDistanceWeighting())
model := bitknn.FitWide(data, labels, bitknn.WithLinearDistanceWeighting())

// one vote counter per class
votes := make([]float64, 2)

k := 2
query := pack.String("fob")
model.Predict1(k, query, bitknn.VoteSlice(votes))

fmt.Println("Votes:", votes)
fmt.Println("Votes:", bitknn.VoteSlice(votes))
}
```

The wide model fitting function [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) accepts the same [Options](#options) as the "narrow" one.

## Options

- `WithLinearDistanceWeighting()`: Apply linear distance weighting (`1 / (1 + dist)`).
Expand Down
15 changes: 15 additions & 0 deletions internal/testrandom/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ func Query() uint64 {
return Source.Uint64()
}

func WideQuery(dim int) []uint64 {
return Data(dim)
}

func Data(size int) []uint64 {
data := make([]uint64, size)
for i := range data {
Expand All @@ -16,6 +20,17 @@ func Data(size int) []uint64 {
return data
}

func WideData(dim int, size int) [][]uint64 {
data := make([][]uint64, size)
for i := range data {
data[i] = make([]uint64, dim)
for j := range dim {
data[i][j] = Source.Uint64()
}
}
return data
}

func Labels(size int) []int {
labels := make([]int, size)
for i := range labels {
Expand Down
19 changes: 19 additions & 0 deletions internal/testrandom/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,31 @@ import (
func TestQuery(t *testing.T) {
_ = testrandom.Query()
}

func TestWideQuery(t *testing.T) {
q := testrandom.WideQuery(5)
if len(q) != 5 {
t.Fatal()
}
}

func TestData(t *testing.T) {
data := testrandom.Data(123)
if len(data) != 123 {
t.Fatal()
}
}

func TestWideData(t *testing.T) {
data := testrandom.WideData(3, 123)
if len(data) != 123 {
t.Fatal()
}
if len(data[0]) != 3 {
t.Fatal()
}
}

func TestLabels(t *testing.T) {
data := testrandom.Labels(123)
if len(data) != 123 {
Expand Down
Loading

0 comments on commit b7fec90

Please sign in to comment.