Prototype Dex Implementation #17

Closed
srush opened this issue Dec 6, 2020 · 4 comments

Comments

@srush
Contributor

srush commented Dec 6, 2020

@apaszke convinced me that Dex can do named tensors even in its current form and provide type checking. It's pretty close. Here's a prototype that implements the current attention formulation using named tensors:

def attention (q: {heads:h & seq2:s2 & key: kt} => Float)
              (k: {heads:h & seq:s  & key: kt} => Float)
              (v: {heads:h & seq:s  & val: vt} => Float) : 
         ({ heads : h & seq2:s2 & val: vt} => Float) =
     q2 = ndim #seq q
     k2 = ndim #seq2 k
     v2 = ndim #seq2 v
     inner = (nfun #seq softmax) (ndot #key q2 k2)
     ndot #seq v2 (ndim #val inner)

The Dex formulation views names as record index types. It automatically generates values of the form #seq that act as lenses (the Iso values in the signatures below) for accessing the corresponding record fields.
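To make the lens idea concrete outside Dex, here is a minimal Python sketch, assuming a record index is modelled as a plain dict from field name to position. The names get_at, pop_at and push_at are hypothetical stand-ins for Dex's getAt, popAt and pushAt; they are not a real library API.

```python
# Hypothetical Python analogue of a Dex `#name` lens. A record index is a
# dict {field: position}; these helpers are illustrative, not Dex's API.

def get_at(name, record):
    """Read one field of the record (analogue of getAt #name i)."""
    return record[name]


def pop_at(name, record):
    """Return the record with one field removed (analogue of popAt #name i)."""
    return {k: v for k, v in record.items() if k != name}


def push_at(name, value, rest):
    """Re-insert a field into a record (analogue of pushAt #name v rest)."""
    if name in rest:
        raise KeyError(f"field {name!r} already present")
    return {**rest, name: value}


# The lens behaves like an isomorphism a ~ (b & c): splitting a record into
# (field value, remaining record) and re-joining it gives back the original.
idx = {"heads": 1, "seq": 3, "key": 0}
value = get_at("seq", idx)
rest = pop_at("seq", idx)
rebuilt = push_at("seq", value, rest)
```

The round trip (get, pop, then push) is the property the Iso type captures in the Dex code.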

Dex's record syntax also lets you write things in roughly the same style we have been using. For example, to sum out heads:

def indexsum (q: {heads:h & seq2:s2 & key: kt} => Float) :                           
         ({seq2:s2 & key: kt} => Float) =                                            
    sum for i:h. q.{heads=i}

or alternatively

def indexsum (q: {heads:h & seq2:s2 & key: kt} => Float) : 
         ({seq2:s2 & key: kt} => Float) =
    (nred #heads sum)  q
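For comparison, a rough Python model of summing out a named dimension. It assumes a named tensor is a pair of a size dict (name to extent) and a data dict keyed by a canonical record key; key, records and index_sum are hypothetical helpers written for this sketch, not part of Dex or any library.

```python
from itertools import product


def key(record):
    """Canonical, hashable form of a record index {name: position}."""
    return tuple(sorted(record.items()))


def records(sizes):
    """Enumerate every full record index for a dict of named sizes."""
    names = sorted(sizes)
    for positions in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, positions))


def index_sum(name, sizes, data):
    """Sum out one named dimension, like `sum for i:h. q.{heads=i}`."""
    out_sizes = {n: s for n, s in sizes.items() if n != name}
    out = {}
    for rest in records(out_sizes):
        out[key(rest)] = sum(data[key({**rest, name: i})]
                             for i in range(sizes[name]))
    return out_sizes, out


# Example tensor with value heads*100 + seq2*10 + key at each index.
sizes = {"heads": 2, "seq2": 3, "key": 4}
q = {key(r): float(r["heads"] * 100 + r["seq2"] * 10 + r["key"])
     for r in records(sizes)}
out_sizes, summed = index_sum("heads", sizes, q)
```

Here summed at {seq2: 1, key: 2} is 12.0 + 112.0, i.e. the two heads added together, and the heads field disappears from the result's record type.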

The only thing I am stuck on (maybe @apaszke knows the answer?) is whether this can do broadcasting. My current implementation manually expands extra dimensions through an ndim call in order to line up the record types. Is there a nice way to get the union of two records automatically? (In particular, is there a version of ndot below where the two a's can be different types?)
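In a dynamically checked setting, the union asked about here can simply be computed at run time. The sketch below is plain Python, not Dex, and says nothing about what Dex's type system can express: ndot_broadcast is a hypothetical variant of ndot whose two arguments may carry different named dimensions. It takes the union of the two record types, rejects names whose sizes disagree, and contracts the given name.

```python
from itertools import product


def key(record):
    """Canonical, hashable form of a record index {name: position}."""
    return tuple(sorted(record.items()))


def records(sizes):
    """Enumerate every full record index for a dict of named sizes."""
    names = sorted(sizes)
    for positions in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, positions))


def union_sizes(s1, s2):
    """Union of two record types; shared names must agree in size."""
    for n in s1.keys() & s2.keys():
        if s1[n] != s2[n]:
            raise ValueError(f"size mismatch on {n!r}: {s1[n]} vs {s2[n]}")
    return {**s1, **s2}


def ndot_broadcast(name, t1, t2):
    """Contract `name`; every other dimension of either tensor is kept."""
    (s1, d1), (s2, d2) = t1, t2
    full = union_sizes(s1, s2)
    out_sizes = {n: s for n, s in full.items() if n != name}
    out = {}
    for rest in records(out_sizes):
        total = 0.0
        for j in range(full[name]):
            r = {**rest, name: j}
            # Each tensor reads only the fields it actually has (broadcasting).
            a = d1[key({n: r[n] for n in s1})]
            b = d2[key({n: r[n] for n in s2})]
            total += a * b
        out[key(rest)] = total
    return out_sizes, out


# q has no seq dimension and k has no seq2 dimension; no ndim call is needed.
q_sizes, k_sizes = {"seq2": 2, "key": 3}, {"seq": 2, "key": 3}
q = (q_sizes, {key(r): 1.0 for r in records(q_sizes)})
k = (k_sizes, {key(r): float(r["key"] + r["seq"]) for r in records(k_sizes)})
out_sizes, scores = ndot_broadcast("key", q, k)
```

The design choice is that broadcasting falls out of each tensor ignoring fields it does not have, which is exactly the step a static record type would need to justify.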

Here's my full implementation if you are interested. It is very similar to @davidweichiang's named-and-numbered style: pop pulls out a vector dimension and push puts it back.

def rename (name1: Iso a (b & c)) (name2: Iso d (b & c))
    (tensor: a => Float) : (d => Float) =

    for i: d.
        value = getAt name2 i
        old = popAt name2 i
        new = pushAt name1 value old
        tensor.new

def ndot (name: Iso a (b & c))
         (tensor1: a => Float )
         (tensor2: a => Float )
  : c => Float =
  for i : c.
     sum for j:b.
        newindex = pushAt name j i 
        tensor1.newindex  * tensor2.newindex

def push (name: Iso a (b & c))
         (tensor1: b => c => Float )
  : a => Float =
  for i : a.
      index1 = getAt name i
      index2 = popAt name i 
      tensor1.index1.index2

def pop (name: Iso a (b & c))
        (tensor1: a => Float )
  : b => c => Float =
  for i : b.
    for j : c.
      index = pushAt name i j
      tensor1.index 

def nred (name: Iso a (b & c))
         (fn : b => Float -> Float) :
         (a => Float -> c => Float) =
   \tensor.
     t2 = pop name tensor
     for j: c. fn (transpose t2).j

def nfun (name: Iso a (b & c))
         (fn : b => Float -> b => Float) :
         (a => Float -> a => Float) =
   \tensor.
     t2 = pop name tensor
     push name (transpose (for j: c. fn (transpose t2).j))
     
def ndim (name: Iso a (b & c)) 
         (tensor: c => Float) : (a => Float) =
    push name for i: b. tensor
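For readers less familiar with Dex, here is a rough Python transliteration of pop, push and nfun, under the assumption that a named tensor is a size dict plus a data dict keyed by canonical record keys. Nested Dex tables (b => c => Float) become a Python list over the popped name whose entries are dicts over the remaining record. All helper names are hypothetical and chosen to mirror the Dex functions above.

```python
import math
from itertools import product


def key(record):
    """Canonical, hashable form of a record index {name: position}."""
    return tuple(sorted(record.items()))


def records(sizes):
    """Enumerate every full record index for a dict of named sizes."""
    names = sorted(sizes)
    for positions in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, positions))


def pop(name, sizes, data):
    """Pull `name` out as an outer list axis (a => Float to b => c => Float)."""
    rest_sizes = {n: s for n, s in sizes.items() if n != name}
    return [{key(r): data[key({**r, name: i})] for r in records(rest_sizes)}
            for i in range(sizes[name])]


def push(name, nested):
    """Inverse of pop: fold the outer list axis back in as field `name`."""
    data = {}
    for i, inner in enumerate(nested):
        for k, v in inner.items():
            data[key({**dict(k), name: i})] = v
    return data


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def nfun(name, fn, sizes, data):
    """Apply a vector-to-vector function along `name`, like nfun #seq softmax."""
    nested = pop(name, sizes, data)
    rest_sizes = {n: s for n, s in sizes.items() if n != name}
    out = [{} for _ in nested]
    for r in records(rest_sizes):
        k = key(r)
        column = fn([row[k] for row in nested])
        for i, v in enumerate(column):
            out[i][k] = v
    return push(name, out)


sizes = {"seq2": 2, "seq": 3}
scores = {key(r): float(r["seq"]) for r in records(sizes)}
attn = nfun("seq", softmax, sizes, scores)
```

As in the Dex version, nfun is just pop, a per-slice function, and push; the transpose calls in the Dex code correspond to reading one column across the popped list.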
@boazbk
Contributor

boazbk commented Dec 6, 2020

Nice! Will take a look (I do need to brush up on my nonexistent Haskell :) )

I have also been trying to think about this in terms of code, and wrote up my initial thoughts at https://hackmd.io/@boazbk/HyUg4D9iw

@srush
Contributor Author

srush commented Dec 6, 2020

Neat, I'll take a look.

The Dex style is interesting. It really does treat indexing entirely through record types, similar to the v1 proposal. There is no named dimension type (DID), simply a mapping from a name to a standard finite dimension type.

So the default would be:

for w in range(W): 
    for h in range(H):
        print(A[{width:w, height:h}])

Alternatively, you can do:

for index in indexset({width: W, height:H}):
    print(A[index])

But as far as I can tell there is nothing in between, i.e. the following would not work without more explicit transformations:

for index in indexset({width: W}):
    print(A[index])

Although if that's the style we arrive at, I'm sure we could make it work.
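A toy Python model of the three cases, under the assumption that a tensor accepts only complete records as indices. RecordTensor and indexset are hypothetical helpers that mirror the pseudocode; they are not Dex and not an existing library.

```python
from itertools import product


def indexset(sizes):
    """Hypothetical: enumerate every full record for a dict of named sizes."""
    names = sorted(sizes)
    for positions in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, positions))


class RecordTensor:
    """Toy tensor indexed only by *complete* records (the Dex-style default)."""

    def __init__(self, sizes, fn):
        self.sizes = dict(sizes)
        self.data = {self._key(r): fn(r) for r in indexset(self.sizes)}

    def _key(self, record):
        # Reject any record that does not name exactly this tensor's fields.
        if set(record) != set(self.sizes):
            raise KeyError(f"expected fields {sorted(self.sizes)}, "
                           f"got {sorted(record)}")
        return tuple(sorted(record.items()))

    def __getitem__(self, record):
        return self.data[self._key(record)]


W, H = 2, 3
A = RecordTensor({"width": W, "height": H},
                 lambda r: 10 * r["width"] + r["height"])

# Style 1: nested loops, building the full record by hand.
nested = [A[{"width": w, "height": h}] for w in range(W) for h in range(H)]

# Style 2: iterate over the index set of the full record type.
flat = [A[index] for index in indexset({"width": W, "height": H})]

# The in-between case: a partial record is rejected.
try:
    A[{"width": 0}]
    partial_ok = True
except KeyError:
    partial_ok = False
```

Both full-record styles visit the same elements (possibly in a different order), while the partial index raises; supporting it would require an explicit pop or a nested table.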

@srush
Contributor Author

srush commented Dec 8, 2020

I am going to close this as I think the type system of Dex is different enough from what we are building that it would be hard to bridge the gap. Dex is neat, Named Tensors is neat, but they are different beasts.

srush closed this as completed Dec 8, 2020
@oxinabox

oxinabox commented Dec 9, 2020

Have you seen the idea of Existential Dimensions?
I got this idea secondhand via @jekbradbury from @dougalm, and I don't know its current status with respect to being able to do it in Dex.
invenia/NamedDims.jl#61
But I think it would be excellent to be able to do.

It solves the problem that various operations that should return a named dimension don't know what that name should be.
For example, a multiply between an unnamed tensor and a named tensor produces one dimension with an unknown name; let's call that an existential name.
Another example is the latent dimension from a matrix factorization, which gives you two existential names, on different arrays, that must be equal to each other.
But if you add two tensors, one fully named (call it publicly named) and one with an existential name, then we know that the existential name must equal that public name.
So you can do a kind of type inference to propagate that name to every other existential name that has to be the same as this one.
And if, while doing this, you end up trying to assign two different public names to the same existential name, you throw an error, because someone has done something invalid.
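The inference described here is essentially unification, so a union-find structure is one natural way to sketch it. The following Python is a hypothetical illustration, not the NamedDims.jl or Dex implementation: existential names are placeholders such as "?0" that get merged into classes, a public name becomes the representative when it joins a class, and two different public names in one class raise an error.

```python
class NameInference:
    """Toy unifier for existential dimension names (hypothetical sketch).

    Existential names look like "?0", "?1", ...; any other string is public.
    """

    def __init__(self):
        self.parent = {}
        self.counter = 0

    def fresh(self):
        """Create a new existential name, e.g. for an unnamed-by-named product."""
        name = f"?{self.counter}"
        self.counter += 1
        self.parent[name] = name
        return name

    def find(self, name):
        """Return the representative of `name`'s class (with path halving)."""
        self.parent.setdefault(name, name)
        while self.parent[name] != name:
            self.parent[name] = self.parent[self.parent[name]]
            name = self.parent[name]
        return name

    @staticmethod
    def is_public(name):
        return not name.startswith("?")

    def unify(self, a, b):
        """Record that two dimensions must carry the same name."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        if self.is_public(ra) and self.is_public(rb):
            raise TypeError(f"conflicting public names {ra!r} and {rb!r}")
        # Keep a public name as the representative so it propagates.
        if self.is_public(ra):
            self.parent[rb] = ra
        else:
            self.parent[ra] = rb


inf = NameInference()
z1, z2 = inf.fresh(), inf.fresh()  # latent dims of a matrix factorization
inf.unify(z1, z2)                  # the two factors must agree
inf.unify(z1, "rank")              # adding to a publicly named tensor fixes it
```

After these steps both existential names resolve to "rank", and a later attempt to unify either of them with a different public name (say "obs") raises a TypeError, matching the error case described above.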

And then there is a fun extension for doing this with namespaces, so you can have one public name per namespace.
If done right, I think this can let you deal with the fact that one library might call observations :obs and another call them :times.
