#jit-compiler #deep-learning #optimization #kernel-fusion

torsh-jit

JIT compilation and kernel fusion for the ToRSh deep learning framework

1 unstable release

0.1.0-alpha.1 Sep 30, 2025

#2011 in Machine learning

35 downloads per month
Used in 2 crates

MIT/Apache

10MB
231K SLoC

torsh-jit

Just-In-Time (JIT) compilation and kernel fusion for the ToRSh deep learning framework.

Overview

The torsh-jit crate provides JIT compilation capabilities for ToRSh, enabling:

  • Kernel Fusion: Automatically combines compatible operations to reduce memory traffic
  • Graph Optimization: Applies various optimization passes to improve performance
  • Multiple Backends: Supports CPU (via Cranelift), CUDA, and Metal code generation
  • TorchScript-like API: A compatible interface that eases migration from PyTorch

Features

Kernel Fusion

  • Element-wise operation fusion (add, mul, relu, etc.; see the sketch after this list)
  • Conv+Activation fusion for common patterns
  • Linear+Activation fusion
  • Reduction operation fusion
  • Custom fusion rules support
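
To make the payoff of element-wise fusion concrete, here is a toy, framework-free sketch of what fusing a mul -> add -> relu chain buys: the unfused version walks memory three times and allocates two temporaries, while the fused version makes a single pass. This is hand-written illustration code, not what torsh-jit actually emits.

fn unfused(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    // Three separate "kernels": each pass reads and writes full arrays.
    let t1: Vec<f32> = a.iter().zip(b).map(|(x, y)| x * y).collect();
    let t2: Vec<f32> = t1.iter().zip(c).map(|(x, y)| x + y).collect();
    t2.iter().map(|x| x.max(0.0)).collect()
}

fn fused(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    // One pass, no temporaries: the shape of kernel a fusion
    // engine aims to generate for an element-wise chain.
    a.iter()
        .zip(b)
        .zip(c)
        .map(|((x, y), z)| (x * y + z).max(0.0))
        .collect()
}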

Graph Optimizations

  • Dead code elimination
  • Constant folding (see the sketch after this list)
  • Common subexpression elimination
  • Algebraic simplification
  • Strength reduction
  • Memory layout optimization
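
A toy before/after for three of these passes (constant folding, algebraic simplification, and strength reduction) on a miniature expression type. torsh-jit's real passes operate on its graph IR, so everything here is illustrative only.

#[derive(Debug, Clone)]
enum Expr {
    Const(f32),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

fn simplify(e: Expr) -> Expr {
    match e {
        Expr::Mul(a, b) => match (simplify(*a), simplify(*b)) {
            // Constant folding: 2.0 * 3.0 => 6.0
            (Expr::Const(x), Expr::Const(y)) => Expr::Const(x * y),
            // Algebraic simplification: x * 1.0 => x
            (x, Expr::Const(c)) if c == 1.0 => x,
            // Strength reduction: x * 2.0 => x + x
            (x, Expr::Const(c)) if c == 2.0 => {
                Expr::Add(Box::new(x.clone()), Box::new(x))
            }
            (x, y) => Expr::Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}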

Code Generation

  • Cranelift backend for CPU
  • CUDA PTX generation (planned)
  • Metal shader generation (planned)
  • Interpreter fallback for unsupported operations

Usage

use torsh_jit::{JitCompiler, JitConfig, FusionStrategy};
use torsh_jit::Device; // assumed re-export; `Device` may live in a core ToRSh crate

// Configure JIT compilation
let config = JitConfig {
    fusion_strategy: FusionStrategy::Default,
    enable_optimizations: true,
    max_fusion_size: 8,
    enable_profiling: false,
    target_device: Device::Cpu,
    enable_caching: true,
};

// Create JIT compiler
let mut compiler = JitCompiler::new(config);

// Build computation graph (usually from a model)
let graph = build_model_graph();

// Compile the graph
let compiled_module = compiler.compile(graph)?;

// Execute with inputs
let outputs = compiled_module.execute(&inputs)?;

// Get execution statistics
let stats = compiled_module.stats();
println!("Execution time: {}μs", stats.total_time_us);

Architecture

Computation Graph

The core representation is a directed acyclic graph (DAG), sketched in code below, where:

  • Nodes represent operations (MatMul, Conv2d, ReLU, etc.)
  • Edges represent data dependencies
  • Supports complex topologies with multiple inputs/outputs
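
In data-structure terms, a minimal sketch of such a graph might look like the following. Field and type names are assumptions for illustration, not torsh-jit's actual definitions.

#[derive(Clone, Copy)]
struct NodeId(usize);

enum OpKind {
    MatMul,
    Conv2d,
    ReLU,
    // ...
}

struct Node {
    op: OpKind,
    inputs: Vec<NodeId>, // incoming edges: data dependencies
}

struct Graph {
    nodes: Vec<Node>,     // stored in topological order
    outputs: Vec<NodeId>, // graph results (multiple outputs supported)
}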

Fusion Engine

The fusion engine identifies patterns of operations that can be combined:

  • Element-wise chains: neg -> abs -> relu
  • Neural patterns: conv -> bn -> relu
  • Custom patterns via FusionRules (sketched below)
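
A self-contained sketch of how a custom rule could be expressed as pattern matching over an operation sequence. The trait below is an assumed shape for illustration; consult the crate docs for the actual FusionRule interface.

#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Conv2d,
    BatchNorm,
    ReLU,
}

trait FusionRule {
    /// If the front of `ops` matches this rule's pattern,
    /// return how many operations the fused kernel replaces.
    fn matches(&self, ops: &[Op]) -> Option<usize>;
}

/// Matches the common conv -> bn -> relu pattern.
struct ConvBnRelu;

impl FusionRule for ConvBnRelu {
    fn matches(&self, ops: &[Op]) -> Option<usize> {
        match ops {
            [Op::Conv2d, Op::BatchNorm, Op::ReLU, ..] => Some(3),
            _ => None,
        }
    }
}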

Optimization Pipeline

Multiple optimization passes are applied in sequence (a driver sketch follows the list):

  1. Dead code elimination
  2. Constant folding
  3. Common subexpression elimination
  4. Algebraic simplification
  5. Strength reduction
  6. Layout optimization
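
The driver for such a pipeline is conceptually simple: run each pass in order, and repeat the whole sequence until nothing changes (a fixed point). A hedged sketch, with Graph as a stand-in type rather than the crate's real one:

struct Graph; // stand-in for the computation graph sketched earlier

trait Pass {
    /// Rewrites the graph in place; returns true if anything changed.
    fn run(&self, graph: &mut Graph) -> bool;
}

fn run_pipeline(graph: &mut Graph, passes: &[Box<dyn Pass>]) {
    loop {
        let mut changed = false;
        for pass in passes {
            // `|=` (not short-circuiting) so every pass runs each round.
            changed |= pass.run(graph);
        }
        if !changed {
            break;
        }
    }
}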

Code Generation

Backend-specific code generators produce optimized kernels (a sketch of the common interface follows the list):

  • CPU: Uses Cranelift for native code generation
  • CUDA: Generates PTX for NVIDIA GPUs
  • Metal: Generates Metal shaders for Apple Silicon
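
One plausible way to organize this (an assumption about the design, not the crate's confirmed internals) is a common trait each backend implements, with the interpreter mentioned above serving as the fallback for unsupported operations:

struct Graph; // stand-in types for illustration
struct Kernel;
struct CodegenError;

trait CodeGen {
    /// Lower an optimized graph to an executable kernel.
    fn generate(&self, graph: &Graph) -> Result<Kernel, CodegenError>;
}

struct CraneliftBackend;

impl CodeGen for CraneliftBackend {
    fn generate(&self, _graph: &Graph) -> Result<Kernel, CodegenError> {
        // Emit native code via Cranelift here; on unsupported
        // operations, callers fall back to the interpreter.
        Ok(Kernel)
    }
}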

Performance

The JIT compiler can provide significant speedups by:

  • Reducing memory traffic through kernel fusion (a worked example follows this list)
  • Eliminating redundant computations
  • Optimizing memory access patterns
  • Leveraging backend-specific optimizations
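
As a rough illustration of the memory-traffic point: for a mul -> add -> relu chain over one million f32 elements (4 MB per array), three unfused kernels move about 32 MB in total, since each pass reads its inputs and writes a full-size result. One fused kernel reads the three inputs and writes a single output, about 16 MB, halving the traffic before any other optimization applies.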

Typical improvements:

  • 2-5x speedup for element-wise operation chains
  • 20-40% improvement for conv+activation patterns
  • Near-zero overhead for cached kernels

Future Work

  • MLIR backend integration
  • Dynamic shape support
  • Control flow (if/while) support
  • Distributed JIT compilation
  • Profile-guided optimization
  • Kernel autotuning

License

Licensed under either of:

  • Apache License, Version 2.0
  • MIT License

at your option.

Dependencies

~124MB
~2M SLoC