Think about data types: use torch.get_default_dtype() as the fallback, or deduce the data type from the inputs.
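Here is a minimal sketch of both options, assuming PyTorch is installed. `infer_dtype` and `make_bias` are hypothetical helpers written for illustration. Combining several input dtypes with `torch.promote_types` is one possible policy and is not prescribed by the note above.

```python
import functools
from typing import Sequence

import torch


def infer_dtype(inputs: Sequence[torch.Tensor]) -> torch.dtype:
    """Hypothetical helper: follow the inputs' dtype if there are any inputs,
    otherwise fall back to the global default floating-point dtype."""
    if not inputs:
        # torch.float32 unless changed with torch.set_default_dtype()
        return torch.get_default_dtype()
    # Combine input dtypes using PyTorch's type-promotion rules,
    # e.g. float16 with float32 -> float32.
    return functools.reduce(torch.promote_types, (t.dtype for t in inputs))


def make_bias(size: int, *inputs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: allocate a zero bias whose dtype matches the inputs
    # so that later arithmetic does not silently up- or down-cast.
    return torch.zeros(size, dtype=infer_dtype(inputs))


# With no inputs, the default dtype is used (float32 unless changed).
print(make_bias(3).dtype)
# With a float64 input, the new tensor follows it.
print(make_bias(3, torch.ones(2, dtype=torch.float64)).dtype)
```

Deducing the dtype from the inputs keeps a float64 or float16 pipeline in that precision. Hard-coding `torch.float32` instead would quietly change it.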