Stannum
Fusing Taichi into PyTorch
Why Stannum?
In differentiable rendering, including neural rendering, rendering algorithms are brought into computer vision pipelines, but some rendering operations (e.g., ray tracing and direct volume rendering) are hard to express as tensor operations and are more naturally written as kernels. Taichi's differentiable kernels enable fast, efficient, and differentiable implementations of such rendering algorithms, while tensor operators provide mathematical expressiveness.
Stannum bridges Taichi and PyTorch so you can take advantage of both kernel-based and operator-based parallelism.
Installation
Install `stannum` with `pip`:
python -m pip install stannum
Make sure you have the following installed:
- PyTorch
- latest Taichi
- For performance reasons, we strongly recommend using Taichi >= 1.1.3 (see Issue #9 for more information)
Differentiability
Stannum does NOT check the differentiability of your kernels, so you may not get correct gradients if your kernel is not differentiable. Please refer to Differentiable Programming of Taichi for more information.
Tin or Tube?
`stannum` mainly has two high-level APIs: `Tin` and `Tube`. `Tin` aims to be the thinnest bridge layer with the least overhead, while `Tube` offers more functionality and convenience at the cost of some extra overhead.
See the comparison below:
| | Tin / EmptyTin | Tube |
|---|---|---|
| Overhead¹ ² | Low ❤️ | A bit more due to automatic memory management |
| Field management | Users must manage Taichi fields ⚠️ | Automatic ♻️ |
| Forward pass bridging | ✅ | ✅ |
| Backward pass gradient bridging | ✅ | ✅ |
| Batching | ❌ | ✅ |
| Variable tensor shapes | ❌ | ✅ |
¹ (Performance tip) `stannum` contains many assertions that make sure you do the right thing, or get the right error when you do something wrong. This is helpful for debugging but incurs a bit of overhead. To get rid of the assertion overhead, pass `-O` to Python, as suggested in the Python documentation on assertions.
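For example, assuming your entry script is named `main.py` (a placeholder name), assertions can be disabled like this:

```shell
# -O turns off assert statements (and sets __debug__ to False),
# removing stannum's assertion overhead
python -O main.py
```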
² See Issue #9 for more information about performance if you want to use `Tube` with legacy Taichi < 1.1.3.
Bugs & Issues
Please feel free to file issues on GitHub. If a runtime error occurs in one of `stannum`'s dependencies, you may also want to check the upstream breaking-change tracker.
Tin
`Tin` and `EmptyTin` are the thinnest layers that bridge Taichi and PyTorch.
Usage
To use them, you will need:
- a Taichi kernel or a Taichi data-oriented-class instance
- Taichi fields
- and registration as shown below
from stannum import Tin
import torch

data_oriented = TiClass()  # some Taichi data-oriented class
device = torch.device("cpu")
kernel_args = (1.0,)
# On new Taichi, kernel_name can be omitted:
#   .register_kernel(data_oriented.forward_kernel, *kernel_args)
tin_layer = Tin(data_oriented, device=device, auto_clear_grad=True) \
    .register_kernel(data_oriented.forward_kernel, *kernel_args, kernel_name="forward") \
    .register_input_field(data_oriented.input_field) \
    .register_output_field(data_oriented.output_field) \
    .register_internal_field(data_oriented.weight_field, name="field name") \
    .finish()  # finish() is required to finish construction
output = tin_layer(input_tensor)
It is NOT necessary to have a `@ti.data_oriented` class as long as you correctly register the fields that your kernel needs for forward and backward computation. In that case, use `EmptyTin` instead, for example:
from stannum import EmptyTin
import torch
import taichi as ti
ti.init(ti.cpu)
input_field = ti.field(ti.f32, shape=())
output_field = ti.field(ti.f32, shape=())
internal_field = ti.field(ti.f32, shape=())
@ti.kernel
def some_kernel(bias: float):
output_field[None] = input_field[None] + internal_field[None] + bias
device = torch.device("cpu")
kernel_args = (1.0,)
tin_layer = EmptyTin(device, True)\
.register_kernel(some_kernel, *kernel_args)\
.register_input_field(input_field)\
.register_output_field(output_field)\
.register_internal_field(internal_field, name="field name")\
.finish() # finish() is required to finish construction
input_tensor = torch.tensor(1.0)  # a 0-dim tensor matching the 0-dim input field
output = tin_layer(input_tensor)
Restrictions & Warning
For input and output, the restrictions are:
- Multiple `input_field`s, `output_field`s, and `internal_field`s can be registered.
- At least one `input_field` and one `output_field` must be registered.
- The order of input tensors must match the registration order of the `input_field`s.
- The output order will align with the registration order of the `output_field`s.
- Kernel args must be acceptable by Taichi kernels, and they will not receive gradients.
Be warned that it is YOUR responsibility to create and manage fields; if you don't manage them properly, memory leaks and use-after-free of fields can happen.
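If you manage field lifetimes manually, one option is Taichi's `FieldsBuilder`; a minimal sketch (the names here are illustrative, not part of `stannum`):

```python
import taichi as ti

ti.init(ti.cpu)

fb = ti.FieldsBuilder()
scratch_field = ti.field(ti.f32)         # declare a field without a shape
fb.dense(ti.i, 16).place(scratch_field)  # lay it out as a dense 1D array of 16 elements
snode_tree = fb.finalize()               # allocate the underlying memory

# ... use scratch_field in kernels / a Tin layer ...

snode_tree.destroy()  # free the memory once nothing references the field anymore
```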
APIs
Constructors
Tin
:
def __init__(self,
data_oriented: Any,
device: torch.device,
auto_clear_grad: bool,
_auto_clear: bool = need_auto_clearing_fields):
"""
Init a Tin instance
@param data_oriented: @ti.data_oriented class instance
@param device: torch.device instance
@param auto_clear_grad: auto clear gradients in fields before backward computation
@param _auto_clear: clear fields before use
"""
EmptyTin
:
def __init__(self,
device: torch.device,
auto_clear_grad: bool,
_auto_clear: bool = need_auto_clearing_fields):
"""
Init an EmptyTin instance
@param device: torch.device instance
@param auto_clear_grad: auto clear gradients in fields before backward computation
@param _auto_clear: clear fields before use, for legacy Taichi
"""
If `_auto_clear` is `True`, all registered fields will be cleared before running the kernel(s), which prevents some undefined behavior caused by uninitialized field memory before Taichi `0.9.1`. Since Taichi `0.9.1`, field memory is automatically cleared upon creation, so `_auto_clear` is no longer necessary, but it remains configurable if desired.
Registrations
Register Kernels
EmptyTin
:
def register_kernel(self, kernel: Callable, *kernel_args: Any, kernel_name: Optional[str] = None):
"""
Register a kernel for forward calculation
@param kernel: Taichi kernel
@param kernel_args: arguments for the kernel
@param kernel_name: kernel name, optional for new Taichi, compulsory for old Taichi
@return: self
"""
Tin
:
def register_kernel(self, kernel: Union[Callable, str], *kernel_args: Any, kernel_name: Optional[str] = None):
"""
Register a kernel for forward calculation
@param kernel: kernel function or kernel name
@param kernel_args: args for the kernel, optional
@param kernel_name: kernel name, optional for new Taichi, compulsory for old Taichi
@return: self
"""
Register Fields
Register input fields:
def register_input_field(self, field: Union[ScalarField, MatrixField],
name: Optional[str] = None,
needs_grad: Optional[bool] = None,
complex_dtype: bool = False):
"""
Register an input field which requires a tensor input in the forward calculation
@param field: Taichi field
@param name: name of this field, default: "input_field_ith"
@param needs_grad: whether the field needs grad, `None` for automatic configuration
@param complex_dtype: whether the input tensor that is going to be filled into this field is complex numbers
@return: self
"""
Register internal fields that are used to store intermediate values if multiple kernels are used:
def register_internal_field(self, field: Union[ScalarField, MatrixField],
needs_grad: Optional[bool] = None,
name: Optional[str] = None,
value: Optional[torch.Tensor] = None,
complex_dtype: bool = False):
"""
Register a field that serves as weights internally and whose values are required by the kernel function
@param field: Taichi field
@param needs_grad: whether the field needs grad, `None` for automatic configuration
@param name: name for the field, facilitating later value setting, `None` for default number naming
@param value: optional initial values from a tensor
@param complex_dtype: whether the input tensor that is going to be filled into this field is complex numbers
@return: self
"""
Register output fields:
def register_output_field(self, field: Union[ScalarField, MatrixField],
name: Optional[str] = None,
needs_grad: Optional[bool] = None,
complex_dtype: bool = False):
"""
Register an output field that backs an output tensor in the forward calculation
@param field: Taichi field
@param name: name of this field, default: "output_field_ith"
@param needs_grad: whether the field needs grad, `None` for automatic configuration
@param complex_dtype: whether the input tensor that is going to be filled into this field is complex numbers
@return: self
"""
Setters
Internal fields
The values of internal fields can be set by:
def set_internal_field(self, field_name: Union[str, int], tensor: torch.Tensor):
"""
Sets the value of an internal field from a tensor
@param field_name: integer(when using default number naming) or string name
@param tensor: values for the field
@return: None
"""
Kernel Arguments
Kernels may need arguments that should not receive gradients. You can set such extra arguments with `Tin/EmptyTin.set_kernel_args()` or pass them in `Tin/EmptyTin.register_kernel()`:
def set_kernel_args(self, kernel: Union[Callable, str], *kernel_args: Any):
"""
Set args for a kernel
@param kernel: kernel function or its name
@param kernel_args: kernel arguments
"""
One example is shown below. Note that the kernel already contains references to fields, which differs from the case of `Tube`.
input_field = ti.field(ti.f32)
output_field = ti.field(ti.f32)
internal_field = ti.field(ti.f32)
@ti.kernel
def some_kernel(adder: float):
output_field[None] = input_field[None] + internal_field[None] + adder
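Assuming a `Tin`/`EmptyTin` layer `tin_layer` has been built with `some_kernel` registered (as in the examples above), the extra argument could later be changed without rebuilding the layer, for instance (a sketch):

```python
# replace the previously registered kernel argument (adder) with a new value
tin_layer.set_kernel_args(some_kernel, 2.0)
```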
Tube
`Tube`, compared to `Tin`, helps you:
- create necessary fields
- manage fields
- (optionally) do automatic batching
- (optionally) create fields and tensors whose shapes are calculated dynamically (e.g., as in convolution)

So, `Tube` is more flexible and convenient, but it also introduces some overhead.
Usage
All you need to do is to register:
- input/intermediate/output tensor shapes instead of fields
- at least one kernel that takes the following as arguments:
  - Taichi fields: these correspond to tensors (and may or may not require gradients)
  - (optional) extra arguments: these will NOT receive gradients (see Set Kernel Extra Arguments below)
- (optional) a `FieldManager` to tweak field creation (see Advanced Field Construction)
- (optional) a `DimensionCalculator` to dynamically calculate the shapes of fields and tensors (see Dynamic Dimension Calculation)
Requirements
Registration order: input tensors, intermediate fields, and output tensors must be registered first, and then kernels.
When registering a kernel, a list of field/tensor names is required, for example `["arr_a", "arr_b", "output_arr"]` below.
This list should correspond to the fields in the argument list of the kernel (e.g., `ti_add()` below).
The order of input tensors passed to the `Tube` call should match the order of the kernel's input fields.
A valid example is shown below:
@ti.kernel
def ti_add(arr_a: ti.template(), arr_b: ti.template(), output_arr: ti.template()):
for i in arr_a:
output_arr[i] = arr_a[i] + arr_b[i]
ti.init(ti.cpu)
cpu = torch.device("cpu")
a = torch.ones(10)
b = torch.ones(10)
tube = Tube(cpu) \
.register_input_tensor((10,), torch.float32, "arr_a", False) \
.register_input_tensor((10,), torch.float32, "arr_b", False) \
.register_output_tensor((10,), torch.float32, "output_arr", False) \
.register_kernel(ti_add, ["arr_a", "arr_b", "output_arr"]) \
.finish()
out = tube(a, b)
Acceptable dimensions of tensors to be registered are:
- `stannum.BatchDim`: the flexible batch dimension; it must be the first dimension, e.g. `(BatchDim, 2, 3, 4)`
- Positive integers: fixed dimensions with the indicated dimensionality
- `stannum.AnyDim`: any number in `[1, +inf)`; only usable in the registration of input tensors
- `stannum.MatchDim(dim_id: str | int)`: dimensions with the same `dim_id` must be of the same dimensionality
  - Restriction: a `MatchDim` must be "declared" in the registration of input tensors first, and only then used in the registration of intermediate and output tensors
  - Example: tensors `a` and `b` registered with dimensions `a: (2, MatchDim("some_dim"), 3)` and `b: (MatchDim("some_dim"), 5, 6)` mean that the dimensions tagged `some_dim` must match, i.e., the second dimension of `a` must equal the first dimension of `b`
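As an illustration of `AnyDim`, here is a sketch of a reduction whose output shape does not depend on the input length (the `ti_sum` kernel and the 0-dim output registration are illustrative, not taken from the `stannum` tests):

```python
import taichi as ti
import torch
from stannum import Tube, AnyDim

ti.init(ti.cpu)

@ti.kernel
def ti_sum(arr: ti.template(), out: ti.template()):
    for i in arr:
        out[None] += arr[i]  # accumulate into a 0-dim output field

tube = Tube(torch.device("cpu")) \
    .register_input_tensor((AnyDim,), torch.float32, "arr", False) \
    .register_output_tensor((), torch.float32, "out", False) \
    .register_kernel(ti_sum, ["arr", "out"]) \
    .finish()

print(tube(torch.ones(7)))   # inputs of any length are accepted; here the result is 7.0
print(tube(torch.ones(42)))  # no re-registration needed for a different length
```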
Automatic Batching
Automatic batching is done simply by running the kernels `batch` times, where `batch` is determined by the size of the leading batch dimension (registered as `(BatchDim, ...)`) of the input tensors.
If any input tensor is batched (i.e., its first registered dimension is `BatchDim`), then all intermediate fields and output tensors must also be registered as batched.
More Examples
A simple one without dimension matching or a batch dimension:
@ti.kernel
def ti_add(arr_a: ti.template(), arr_b: ti.template(), output_arr: ti.template()):
for i in arr_a:
output_arr[i] = arr_a[i] + arr_b[i]
ti.init(ti.cpu)
cpu = torch.device("cpu")
a = torch.ones(10)
b = torch.ones(10)
tube = Tube(cpu) \
.register_input_tensor((10,), torch.float32, "arr_a", False) \
.register_input_tensor((10,), torch.float32, "arr_b", False) \
.register_output_tensor((10,), torch.float32, "output_arr", False) \
.register_kernel(ti_add, ["arr_a", "arr_b", "output_arr"]) \
.finish()
out = tube(a, b)
With dimension matching:
ti.init(ti.cpu)
cpu = torch.device("cpu")
tube = Tube(cpu) \
.register_input_tensor((MatchDim(0),), torch.float32, "arr_a", False) \
.register_input_tensor((MatchDim(0),), torch.float32, "arr_b", False) \
.register_output_tensor((MatchDim(0),), torch.float32, "output_arr", False) \
.register_kernel(ti_add, ["arr_a", "arr_b", "output_arr"]) \
.finish()
dim = 10
a = torch.ones(dim)
b = torch.ones(dim)
out = tube(a, b)
assert torch.allclose(out, torch.full((dim,), 2.))
dim = 100
a = torch.ones(dim)
b = torch.ones(dim)
out = tube(a, b)
assert torch.allclose(out, torch.full((dim,), 2.))
With batch dimension:
@ti.kernel
def int_add(a: ti.template(), b: ti.template(), out: ti.template()):
out[None] = a[None] + b[None]
ti.init(ti.cpu)
b = torch.tensor(1., requires_grad=True)
batched_a = torch.ones(10, requires_grad=True)
tube = Tube() \
.register_input_tensor((BatchDim,), torch.float32, "a") \
.register_input_tensor((), torch.float32, "b") \
.register_output_tensor((BatchDim,), torch.float32, "out", True) \
.register_kernel(int_add, ["a", "b", "out"]) \
.finish()
out = tube(batched_a, b)
loss = out.sum()
loss.backward()
assert torch.allclose(torch.ones_like(batched_a) + 1, out)
assert b.grad == 10.
assert torch.allclose(torch.ones_like(batched_a), batched_a.grad)
For more examples of valid and invalid usage, please see the test files in the test folder.
APIs
Constructor
def __init__(self,
device: Optional[torch.device] = None,
persistent_field: bool = True,
enable_backward: bool = True):
"""
Init a tube
@param device: optional, the torch.device tensors are on; if it's None, the device is determined by the input tensors
@param persistent_field: whether to keep created fields alive during the forward pass.
If True, created fields will not be destroyed until the compute graph is cleaned up;
otherwise they are destroyed right after the forward pass and re-created in the backward pass.
The two modes exist due to a Taichi performance issue, see https://github.com/taichi-dev/taichi/pull/4356
@param enable_backward: whether to enable backward gradient computation; disabling it improves forward-pass
performance, but attempting backward computation will raise a runtime error.
"""
Registrations
Register input tensor shapes:
def register_input_tensor(self,
dims: Tuple[DimOption] | List[DimOption],
dtype: torch.dtype,
name: str,
requires_grad: Optional[bool] = None,
field_manager: Optional[FieldManager] = None):
"""
Register an input tensor
@param dims: dims can contain `None`, positive and negative numbers,
for restrictions and requirements, see README
@param dtype: torch data type
@param name: name of the tensor and corresponding field
@param requires_grad: optional, if it's None, it will be determined by input tensor
@param field_manager: customized field manager, if it's None, a DefaultFieldManger will be used
"""
Register intermediate field shapes:
def register_intermediate_field(self,
dims_or_calc: Tuple[DimOption] | List[DimOption] | Callable | DimensionCalculator,
ti_dtype: TiDataType,
name: str,
needs_grad: bool,
field_manager: Optional[FieldManager] = None):
"""
Register an intermediate field,
which can be useful if multiple kernels are used and intermediate results between kernels are stored
@param dims_or_calc: dims can contain `None`, positive and negative numbers,
for restrictions and requirements, see README; Or DimensionCalculator instance or a function can be passed
to dynamically calculate dimensions
@param ti_dtype: taichi data type
@param name: name of the field
@param needs_grad: whether the field needs gradients
@param field_manager: customized field manager, if it's None, a DefaultFieldManager will be used
"""
Register output tensor shapes:
def register_output_tensor(self,
dims_or_calc: Tuple[DimOption] | List[DimOption] | Callable | DimensionCalculator,
dtype: torch.dtype,
name: str,
requires_grad: bool,
field_manager: Optional[FieldManager] = None):
"""
Register an output tensor
@param dims_or_calc: dims can contain `None`, positive and negative numbers,
for restrictions and requirements, see README; Or DimensionCalculator instance or a function can be passed
to dynamically calculate dimensions
@param dtype: torch data type
@param name: name of the tensor and corresponding field
@param requires_grad: whether the output tensor requires gradients
@param field_manager: customized field manager, if it's None, a DefaultFieldManager will be used
"""
Register kernels:
def register_kernel(self, kernel: Callable, tensor_names: List[str], *extra_args: Any, name: Optional[str] = None):
"""
Register a Taichi kernel
@param kernel: Taichi kernel. For requirements, see README
@param tensor_names: the names of registered tensors that are to be used in this kernel
@param extra_args: any extra arguments passed to the kernel
@param name: name of this kernel, if it's None, it will be kernel.__name__
"""
Set Kernel Extra Arguments
Kernels may need extra arguments that should not receive gradients. You can set such extra arguments with `Tube.set_kernel_extra_args()` or pass them in `Tube.register_kernel()`:
def set_kernel_extra_args(self, kernel: Callable | str, *extra_args: Any):
"""
Set args for a kernel
@param kernel: kernel function or its name
@param extra_args: extra kernel arguments
"""
One example kernel is shown below, in which `multiplier` is an extra kernel argument.
@ti.kernel
def mul(arr: ti.template(), out: ti.template(), multiplier: float):
for i in arr:
out[i] = arr[i] * multiplier
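Assuming a `Tube` has `mul` registered with tensors named `"arr"` and `"out"` (names chosen for illustration), the extra argument can be supplied at registration time and changed later without rebuilding the `Tube` (a sketch):

```python
tube = Tube(cpu) \
    .register_input_tensor((10,), torch.float32, "arr", False) \
    .register_output_tensor((10,), torch.float32, "out", False) \
    .register_kernel(mul, ["arr", "out"], 2.0) \
    .finish()  # 2.0 is the initial value of multiplier

tube.set_kernel_extra_args(mul, 3.0)  # change multiplier for subsequent forward passes
```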
Advanced Field Construction
With a `FieldManager`, you can tweak how fields are constructed in order to improve performance in kernel computations.
By supplying a customized `FieldManager` when registering a field, you can construct the field however you want.
WARNING:
- If you don't know why constructing fields differently can improve performance, don't use this feature.
- If you don't know how to construct fields differently, please refer to Taichi field documentation.
Example
In `auxiliary.py`, `FieldManager` is defined as an abstract class:
class FieldManager(ABC):
"""
FieldManagers enable potential flexible field constructions and manipulations.
For example, instead of laying out a multidimensional field in the ordinary way,
you can use hierarchical placements for fields, which may give dramatic performance improvements
depending on the application. Since hierarchical fields may not have the same shape as the input tensor,
it's YOUR responsibility to write a FieldManager that correctly transforms field values into/from tensors
"""
@abstractmethod
def construct_field(self,
fields_builder: ti.FieldsBuilder,
concrete_tensor_shape: Tuple[int, ...],
needs_grad: bool) -> Union[ScalarField, MatrixField]:
pass
@abstractmethod
def to_tensor(self, field: Union[ScalarField, MatrixField]) -> torch.Tensor:
pass
@abstractmethod
def grad_to_tensor(self, grad_field: Union[ScalarField, MatrixField]) -> torch.Tensor:
pass
@abstractmethod
def from_tensor(self, field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
pass
@abstractmethod
def grad_from_tensor(self, grad_field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
pass
One example is the `DefaultFieldManager` in `tube.py`, defined as:
class DefaultFieldManager(FieldManager):
"""
Default field manager which layouts data in tensors by constructing fields
with the ordinary multidimensional array layout
"""
def __init__(self,
dtype: TiDataType,
complex_dtype: bool,
device: torch.device):
self.dtype: TiDataType = dtype
self.complex_dtype: bool = complex_dtype
self.device: torch.device = device
def construct_field(self,
fields_builder: ti.FieldsBuilder,
concrete_tensor_shape: Tuple[int, ...],
needs_grad: bool) -> Union[ScalarField, MatrixField]:
assert not fields_builder.finalized
if self.complex_dtype:
field = ti.Vector.field(2, dtype=self.dtype, needs_grad=needs_grad)
else:
field = ti.field(self.dtype, needs_grad=needs_grad)
if needs_grad:
fields_builder \
.dense(axes(*range(len(concrete_tensor_shape))), concrete_tensor_shape) \
.place(field, field.grad)
else:
fields_builder.dense(axes(*range(len(concrete_tensor_shape))), concrete_tensor_shape).place(field)
return field
def to_tensor(self, field: Union[ScalarField, MatrixField]) -> torch.Tensor:
tensor = field.to_torch(device=self.device)
if self.complex_dtype:
tensor = torch.view_as_complex(tensor)
return tensor
def grad_to_tensor(self, grad_field: Union[ScalarField, MatrixField]) -> torch.Tensor:
tensor = grad_field.to_torch(device=self.device)
if self.complex_dtype:
tensor = torch.view_as_complex(tensor)
return tensor
def from_tensor(self, field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
if self.complex_dtype:
tensor = torch.view_as_real(tensor)
field.from_torch(tensor)
def grad_from_tensor(self, grad_field: Union[ScalarField, MatrixField], tensor: torch.Tensor):
if self.complex_dtype:
tensor = torch.view_as_real(tensor)
grad_field.from_torch(tensor)
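A customized manager is then passed at registration time via the `field_manager` parameter; for example (a sketch, where `MyFieldManager` stands for your own `FieldManager` subclass and `mul` is the kernel from above):

```python
tube = Tube(cpu) \
    .register_input_tensor((10,), torch.float32, "arr", False, field_manager=MyFieldManager()) \
    .register_output_tensor((10,), torch.float32, "out", False) \
    .register_kernel(mul, ["arr", "out"], 2.0) \
    .finish()
```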
Dynamic Dimension Calculation
We first look at what "dimension" means in the context of `stannum`, and then at how to leverage a `DimensionCalculator` to enable dynamic dimension calculation.
Dimension != Shape
In `stannum`, a "dimension" is virtual while a "shape" is concrete. In other words, dimensions can be either integers or values of the enum `DimEnum` (namely `AnyDim`, `BatchDim`, and `MatchDim(dim_id)`), while shapes are always integers.
Once dimensions are concretized, a shape is produced. For example, the dimensions `(AnyDim, 10)` contain both an integer and a `DimEnum`. To concretize them, we need a matching tensor, say a tensor of shape `(123, 10)`; the dimensions are then concretized as `(123, 10)`, since `AnyDim` matches `123`. The same goes for dimensions containing `BatchDim` and `MatchDim(dim_id)`, apart from some unification that is more complicated.
Sidenote for developers: in `tube.py`, `concretize_dimensions_and_unify()` does exactly this concretization and unification!
Of course, in this context, dimensions can also be plain shapes, which makes dimensions a superset of shapes and thus more expressive.
DimensionCalculator
To calculate dimensions dynamically, we provide the `DimensionCalculator` API, an abstract class containing only one method. Alternatively, you can provide a closure as a duck-typed `DimensionCalculator` implementation.
Wait, why do we need dimensions/shapes in the first place?
`Tube` helps you manage fields automatically; by "manage", we mean it automatically creates and destroys fields. Field creation requires (concrete) shapes rather than (virtual) dimensions. At the same time, we need some way to express batching dimensions, matching dimensions, and don't-care (any) dimensions, hence dimensions.
The specification of `DimensionCalculator` is as simple as this:
class DimensionCalculator(ABC):
"""
An interface for implementing a calculator that hints Tube how to construct fields
"""
@abstractmethod
def calc_dimension(self,
field_name: str,
input_dimensions: Dict[str, Tuple[DimOption, ...]],
input_tensor_shapes: Dict[str, Tuple[int, ...]]) -> Tuple[DimOption, ...]:
"""
Calculate dimensions for an output/intermediate field
@param field_name: the name of the field for which the dimensions are calculated
@param input_dimensions: the dict mapping names of input fields to their registered dimensions
@param input_tensor_shapes: the dict mapping names of input fields to
the shapes of the input tensors that correspond to those fields
"""
pass
`field_name` is the name of the intermediate/output field for which dimensions are calculated. `input_dimensions` gives you the registered dimensions of all input fields, and `input_tensor_shapes` gives you the concrete shapes of all input tensors, both keyed by the names of the corresponding input fields.
Example: Convolution
In the case of convolution, the shapes and dimensions of the convolution kernel and of the input (say, an image) are needed to compute the shape of the output.
In test_dynamic_shape.py, there is a simple `DimensionCalculator` for 1D convolution, namely `D1ConvDC`, and you can see how to use it in the test cases.
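As a rough illustration (not the actual `D1ConvDC` from the tests), a `DimensionCalculator` for a "valid" 1D convolution might look like the sketch below, assuming input fields named `"signal"` and `"kernel"`, both registered with 1D dimensions; the class name and import path are assumptions:

```python
from typing import Dict, Tuple
from stannum import DimensionCalculator

class Conv1DOutputDim(DimensionCalculator):
    def calc_dimension(self,
                       field_name: str,
                       input_dimensions: Dict[str, Tuple],
                       input_tensor_shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int, ...]:
        # output length of a "valid" 1D convolution: len(signal) - len(kernel) + 1
        signal_len = input_tensor_shapes["signal"][0]
        kernel_len = input_tensor_shapes["kernel"][0]
        return (signal_len - kernel_len + 1,)
```

An instance of such a calculator would then be passed as `dims_or_calc` when registering the output tensor, e.g. `.register_output_tensor(Conv1DOutputDim(), torch.float32, "out", False)`.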
Complex Number Support
When registering input fields and output fields, you can pass `complex_dtype=True` to enable simple complex tensor input and output support, for instance `Tin(..).register_input_field(input_field, complex_dtype=True)`.
Currently, complex tensor support is limited in that complex numbers are represented as bare 2D vectors, since Taichi has no official support for complex numbers yet.
This means that although `stannum` provides some facilities to deal with complex tensor input and output, you have to define and perform the operations on the proxy 2D vectors yourself.
In practice, we now have these limitations:
- The registered field with `complex_dtype=True` must be an appropriate `VectorField` or `ScalarField`:
  - If it's a `VectorField`, `n` should be `2`, like `v_field = ti.Vector.field(n=2, dtype=ti.f32, shape=(2, 3, 4, 5))`
  - If it's a `ScalarField`, its last dimension should be `2`, like `field = ti.field(ti.f32, shape=(2, 3, 4, 5, 2))`
  - Both examples above accept tensors of `dtype=torch.cfloat, shape=(2, 3, 4, 5)`
- The semantics of complex numbers are not preserved in kernels, so you are manipulating regular fields, and as a consequence you need to implement complex-number operators yourself.
  - Example:

    @ti.kernel
    def element_wise_complex_mul(self):
        for i in self.complex_array0:
            # this is not complex multiplication, but only an element-wise 2D vector multiplication
            self.complex_output_array[i] = self.complex_array0[i] * self.complex_array1[i]
Contribution
PRs are always welcome; please see the TODOs and issues.
TODOs
Documentation
- Improve documentation
Features
- PyTorch-related:
  - PyTorch checkpointing and model saving
  - Proxy `torch.nn.parameter.Parameter` for weight fields, for use with optimizers
- Taichi-related:
  - Wait for Taichi to support a native PyTorch tensor view to optimize performance (i.e., no need to copy data back and forth)
  - Automatic batching for `Tin`:
    - waiting for upstream Taichi improvements
    - workaround for now: do static manual batching, i.e., extend fields with one extra leading dimension for batching