数据模型与类型
triton_ops.models 模块集中放置了输入规格、性能指标、调优结果以及 FP8 格式工具等 dataclass。
TensorSpec
TensorSpec(
shape: tuple[int, ...],
dtype: torch.dtype,
device: str = "cuda",
contiguous: bool = True,
)
核心方法:
validate(tensor) -> boolcreate_tensor(fill_value=None) -> torch.Tensor
输入规格 dataclass
可用构造器:
RMSNormRoPEInput.from_shapes(...)GatedMLPInput.from_shapes(...)FP8GEMMInput.from_shapes(...)
这些类很适合用于测试、脚手架和样例生成。
关于 FP8GEMMInput 的提醒:
- 当 PyTorch 支持原生 float8 dtype 时,它会优先选择 float8。
- 但仓库当前维护的运行路径仍然是量化页中描述的
uint8兼容表示。
KernelMetrics
KernelMetrics(
latency_ms: float,
throughput_tflops: float,
bandwidth_gbps: float,
bandwidth_utilization: float,
)
这是 benchmark 与 autotuning 共享的性能指标容器。
TuningResult
TuningResult(
best_config: dict[str, Any],
metrics: KernelMetrics,
all_results: list[tuple[dict[str, Any], KernelMetrics]] = [],
problem_size: tuple[int, ...] | None = None,
device: str | None = None,
)
用于保存最优配置、指标以及可选的完整搜索结果。
FP8Format
FP8Format 保存了 FP8 E4M3 相关的常量和工具函数。
常用成员:
FP8Format.max_valueFP8Format.min_normalFP8Format.compute_scale(tensor)FP8Format.compute_scale_per_channel(tensor, dim=0)FP8Format.is_in_range(tensor, scale)