
mwaurac/ASTen


ASTen (A Small Tensor Library)

ASTen is a small, educational tensor library inspired by PyTorch. It is designed to help users understand the internal workings of PyTorch by providing a simplified implementation of its core components.

Inspiration

  • After reading ezyang's "PyTorch internals", I was motivated to build a similar library: something simple, but still a faithful reflection of PyTorch itself.

Current Features

  • Tensor: A multi-dimensional array object.
  • View: Create a new tensor that is a view of an existing tensor.
  • Reshape: like View, but also works on non-contiguous tensors.
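View and reshape both rest on strided tensor metadata: a shape, per-dimension strides, and a storage offset over a flat buffer that several tensors can share. This is the same idea PyTorch uses. The following is a minimal pure-Python sketch of that mechanism. It assumes a simplified rule where `view()` requires a contiguous tensor and `reshape()` falls back to a copy otherwise. `StridedTensor`, `contiguous_strides`, and `tolist_flat` are hypothetical names chosen for illustration; they are not ASTen's actual API.

```python
from dataclasses import dataclass
from itertools import product
from math import prod


def contiguous_strides(shape):
    """Row-major strides, measured in elements, for a given shape."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


@dataclass
class StridedTensor:  # hypothetical illustration, not ASTen's Tensor
    storage: list      # flat buffer, shared between a tensor and its views
    shape: tuple
    strides: tuple
    offset: int = 0

    @classmethod
    def from_list(cls, data, shape):
        data = list(data)
        if len(data) != prod(shape):
            raise ValueError("data length does not match shape")
        return cls(data, tuple(shape), contiguous_strides(shape))

    def is_contiguous(self):
        # Simplification: PyTorch ignores strides of size-1 dims; this does not.
        return self.strides == contiguous_strides(self.shape)

    def __getitem__(self, index):
        # Map a multi-dimensional index to a position in the flat storage.
        return self.storage[self.offset + sum(i * s for i, s in zip(index, self.strides))]

    def transpose(self, d0, d1):
        # Swapping shape and strides yields a non-contiguous view; nothing is copied.
        shape, strides = list(self.shape), list(self.strides)
        shape[d0], shape[d1] = shape[d1], shape[d0]
        strides[d0], strides[d1] = strides[d1], strides[d0]
        return StridedTensor(self.storage, tuple(shape), tuple(strides), self.offset)

    def view(self, shape):
        # A view only rewrites metadata, so here it demands contiguous memory.
        if prod(shape) != prod(self.shape):
            raise ValueError("view shape has a different number of elements")
        if not self.is_contiguous():
            raise RuntimeError("view() needs a contiguous tensor; use reshape()")
        return StridedTensor(self.storage, tuple(shape), contiguous_strides(shape), self.offset)

    def tolist_flat(self):
        # Read elements in logical row-major order, honoring strides.
        return [self[idx] for idx in product(*(range(n) for n in self.shape))]

    def contiguous(self):
        # Copy into a fresh buffer only when the layout is not already row-major.
        if self.is_contiguous():
            return self
        return StridedTensor.from_list(self.tolist_flat(), self.shape)

    def reshape(self, shape):
        # Reuse view() when possible, otherwise copy first and then view.
        return self.contiguous().view(shape)


x = StridedTensor.from_list(range(6), (6,))
y = x.view((2, 3))        # shares x.storage
t = y.transpose(0, 1)     # shape (3, 2), strides (1, 3): non-contiguous
r = t.reshape((6,))       # t.view((6,)) would raise, so reshape copies
print(y.strides, t.strides, r.tolist_flat())  # (3, 1) (1, 3) [0, 3, 1, 4, 2, 5]
```

Design note: PyTorch's own `view()` also accepts some non-contiguous tensors whose strides happen to be compatible with the requested shape, so the blanket contiguity check above is a deliberate simplification.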

Planned Features

  • Core NN ops and kernels: ReLU, softmax, and cross-entropy.
  • Automatic differentiation and training primitives.
  • nn building blocks: nn.Module, nn.Parameter, and nn.Embedding.
  • Attention and transformer-focused kernels, including flash attention.
  • Convolution operators: 2D conv and 3D conv.
  • Expanded CUDA support across operators and kernels.

By the end of this project, I want to train GPT-2 using only this library.
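When the planned ReLU, softmax, and cross-entropy kernels land, a small NumPy reference is handy for checking their outputs. The sketch below is one such reference, assuming `(N, C)` float logits and `(N,)` integer class targets with the loss averaged over the batch, which is intended to match the default behaviour of PyTorch's `torch.nn.functional.cross_entropy`. The function names are hypothetical and not part of ASTen.

```python
import numpy as np


def relu(x):
    # Element-wise max(x, 0).
    return np.maximum(x, 0)


def softmax(logits, axis=-1):
    # Softmax is shift-invariant, so subtracting the max keeps exp() from overflowing.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def log_softmax(logits, axis=-1):
    # log softmax = shifted - log(sum(exp(shifted))), computed stably.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def cross_entropy(logits, targets):
    # Negative log-probability of each target class, averaged over the batch.
    logp = log_softmax(logits, axis=1)
    return -logp[np.arange(len(targets)), targets].mean()
```

Fusing the max-subtraction into the kernel (rather than calling `np.exp` on raw logits) is what keeps large logits such as `1000.0` from producing `inf` or `nan`.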

Project Structure

The project is organized into the following directories, mimicking a simplified PyTorch structure:

.
├── ASTen/
│   ├── csrc/              # C++/pybind11 bindings
│   ├── _tensor.py         # Python Tensor API wrapper
│   ├── _creation.py       # Python creation ops (tensor/zeros/ones)
│   └── __init__.py        # Public package API
├── aten/
│   └── native/
├── c10/
│   └── core/
├── test/
│   └── frontend/          # Python frontend tests
├── setup.py
└── README.md

Installation

Install ASTen from source in editable mode:

pip install -e .

Run tests with:

pytest test

Usage

Here is a simple example of how to create a tensor and use the view operation:

import ASTen
import numpy as np

# Create a tensor from a numpy array
data = np.array([1, 2, 3, 4, 5, 6])
x = ASTen.tensor(data)

print(f"Original tensor shape: {x.shape}")

# Create a view of the tensor
y = x.view((2, 3))
print(f"Viewed tensor shape: {y.shape}")

Contributing

Contributions are welcome! Please feel free to submit a pull request or open an issue.

License

This project is licensed under the MIT License.

About

A Small Tensor Library implemented in C
