TinyDL 是一个用 Java 实现的轻量级深度学习框架,旨在为深度学习初学者和研究人员提供清晰、简洁的核心功能实现。该框架参考了 PyTorch 的设计理念,实现了自动微分、神经网络层、优化器等核心组件,特别适合:
- 🎓 深度学习教学与学习:代码结构清晰,中文注释详尽,便于理解底层原理
- 🔬 学术研究与实验:模块化设计,易于扩展和定制
- 🚀 快速原型开发:提供完整的机器学习工具链,支持多种AI应用场景
- 💡 算法验证:在JVM环境中进行深度学习算法的验证和调试
- 🤖 多模态AI应用:支持NLP、CV、强化学习等多个AI领域
- NdArray 核心类:支持标量、向量、矩阵及高维张量操作,完整序列化支持
- 丰富数学运算:四则运算、矩阵乘法、形状变换、广播机制
- 内存高效:针对 CPU 优化的数组实现,支持缓存机制
- 动态计算图:运行时构建,支持复杂的控制流
- 双模式反向传播:递归和迭代两种实现,避免栈溢出
- 自动梯度计算:一键调用
backward()完成反向传播 - 灵活的梯度控制:支持梯度开关和计算图切断
- 完整的层实现:
- 全连接层(
LinearLayer、AffineLayer) - 优化卷积层(
ConvLayer支持偏置、Xavier初始化) - 高级卷积层(
BatchNormLayer、DepthwiseSeparableConvLayer) - 循环层(
LstmLayer、SimpleRnnLayer) - Transformer组件(
MultiHeadAttention、GPT2Block) - 嵌入层(
Embedding、GPT2TokenEmbedding) - 激活函数(
ReLU、Sigmoid、Tanh、Softmax)
- 全连接层(
- 模块化设计:
Layer和Block支持灵活组合和残差连接 - 预构建网络:
MlpBlock、LstmBlock、SequentialBlock、GPT2Model
- 数据处理:多种内置数据集(MNIST、螺旋数据、Word2Vec专用数据集等)
- 损失函数:交叉熵、均方误差、掩码损失等
- 优化算法:SGD、Adam 优化器,支持学习率调度
- 训练控制:
Trainer类提供完整训练循环和并行训练支持 - 效果评估:准确率、回归误差等评估器
- 模型管理:完整的序列化系统,支持检查点和压缩存储
- 多线程训练:自动检测模型并行性,智能线程数分配
- 梯度聚合:支持并行批次处理和梯度累积,完整的GradientAggregator实现
- 资源管理:完整的线程池管理和异常处理,包含ParallelBatchProcessor
- 性能监控:并行训练统计和性能分析,支持训练性能对比
- Word2Vec实现:支持Skip-gram和CBOW两种模式,完整词汇管理
- 负采样优化:基于词频的高效负采样算法,专用Word2VecDataSet
- GPT-2模型:完整的小规模语言模型实现,支持Token嵌入和位置编码
- MoE架构:混合专家模型,包含门控网络和专家网络,支持动态专家选择
- Transformer组件:多头注意力、位置编码、层归一化等完整实现
- 文本生成:支持自回归文本生成和下一个token预测
- 词向量操作:相似度计算、最相似词查找、词向量可视化等
- CNN深度优化:Im2Col/Col2Im缓存机制,性能大幅提升
- 高级卷积技术:深度可分离卷积、批量归一化
- 灵活网络配置:SimpleConvNet支持残差连接和自定义架构
- 性能基准测试:完整的CNN性能分析和对比工具
- 训练监控:实时显示损失和准确率变化,支持文件日志
- 结果可视化:基于 JFreeChart 的图表绘制
- 模型结构图:UML 工具可视化网络架构
- 性能分析:并行训练效率统计和资源使用监控
- 强化学习可视化:支持RL训练过程可视化和性能分析
- 分类任务:螺旋数据分类、手写数字识别
- 回归任务:曲线拟合、时间序列预测
- 序列建模:RNN 序列预测、LSTM 应用、Seq2Seq架构
- 自然语言处理:Word2Vec训练、GPT-2文本生成、MoE-GPT实现
- 计算机视觉:卷积网络优化、深度可分离卷积、批量归一化
- Transformer应用:多头注意力、位置编码、Transformer编码器
- 强化学习:DQN智能体、REINFORCE算法、多臂老虎机问题、CartPole和GridWorld环境
- 并行训练:多线程训练演示和性能对比
- 模型序列化:完整的模型保存和加载示例
- 核心算法:深度Q网络(DQN)、策略梯度(REINFORCE)、多臂老虎机算法
- 环境支持:CartPole倒立摆、GridWorld网格世界、自定义环境接口
- 训练工具:经验回放、目标网络、ε-贪婪策略、负载均衡
- 智能体管理:完整的Agent抽象和具体实现
- 性能评估:实时监控训练指标和性能分析
TinyDL 采用分层模块化架构,各组件职责明确,易于理解和扩展:
graph TB
subgraph "应用层"
E[example包<br/>示例程序] --> M[modality包<br/>应用领域]
end
subgraph "机器学习层"
M --> ML[mlearning包<br/>训练/推理组件]
end
subgraph "神经网络层"
ML --> N[nnet包<br/>网络层和块]
end
subgraph "计算层"
N --> F[func包<br/>自动微分]
F --> ND[ndarr包<br/>多维数组]
end
style E fill:#fff2cc
style M fill:#d5e8d4
style ML fill:#dae8fc
style N fill:#f8cecc
style F fill:#e1d5e7
style ND fill:#ffcce6
NdArray:多维数组核心实现,支持各种数学运算和序列化Shape:形状定义和操作,支持动态维度NdArrayUtil:数组工具方法和优化函数- 设计理念:提供高效的CPU计算支持,为上层提供数值计算基础
Variable:变量抽象,记录计算图节点,支持迭代和递归反向传播Function:所有数学运算的基类- 运算类别:
base/:四则运算(Add、Sub、Mul、Div)math/:数学函数(Sin、Exp、ReLU、Sigmoid等)matrix/:矩阵运算(MatMul、Reshape、Softmax等)loss/:损失函数(MeanSE、SoftmaxCE等)
- 设计理念:通过计算图实现动态自动微分,避免栈溢出
Layer:网络层接口,支持参数管理Block:网络块抽象,可组合多个层- 层实现:
layer/dnn/:全连接层(LinearLayer、AffineLayer)layer/cnn/:卷积层、池化层、BatchNorm、DepthwiseSeparableConvlayer/rnn/:循环神经网络层(SimpleRnn、LSTM)layer/transformer/:Transformer组件(多头注意力、位置编码、GPT-2)layer/activate/:激活函数层layer/embedding/:嵌入层
- 块实现:
block/:SequentialBlock、MlpBlock、LstmBlockblock/transformer/:GPT2Block、TransformerEncoder等
- 设计理念:模块化组件,支持复杂的网络架构构建
Model:模型封装器,支持序列化和模型信息管理Trainer:训练控制器,支持并行训练和简化版实现DataSet:数据集抽象和实现,包括Word2VecDataSet等专用数据集Loss:损失函数集合(交叉熵、均方误差等)Optimizer:优化器实现(SGD、Adam)Evaluator:模型评估器和准确率计算ModelSerializer:完整的模型序列化系统ParameterManager:参数管理和操作工具Monitor:训练过程监控和可视化- 并行训练:
parallel/包提供多线程训练支持 - 设计理念:提供企业级的机器学习开发工具链
cv/:计算机视觉应用SimpleConvNet:增强的卷积神经网络,支持残差连接
nlp/:自然语言处理应用Word2Vec:完整的词向量实现(Skip-gram/CBOW)GPT2Model:小规模GPT-2语言模型
- 设计理念:针对特定领域的高层封装和优化
classify/:分类任务示例(螺旋数据、MNIST)regress/:回归任务示例(曲线拟合、RNN预测)nlp/:自然语言处理示例(Word2Vec、GPT-2)cv/:计算机视觉示例(卷积网络优化)transformer/:Transformer相关示例parallel/:并行训练示例embedd/:嵌入层示例- 设计理念:展示框架最新功能,提供学习参考
- Java 8+
- Maven 3.6+
<dependencies>
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>1.0.7</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
</dependencies>理解TinyDL的核心——自动微分机制:
// 创建变量(支持梯度计算)
Variable x = new Variable(new NdArray(2.0f)).setName("x");
Variable y = new Variable(new NdArray(3.0f)).setName("y");
// 构建计算表达式 z = (x + y) * x = (2 + 3) * 2 = 10
Variable z = x.add(y).mul(x);
// 自动微分:计算 dz/dx 和 dz/dy
z.backward();
// 查看梯度结果
System.out.println("z的值: " + z.getValue().getNumber()); // 输出: 10.0
System.out.println("x的梯度 dz/dx: " + x.getGrad().getNumber()); // 输出: 5.0 (y + x)
System.out.println("y的梯度 dz/dy: " + y.getGrad().getNumber()); // 输出: 2.0 (x)使用Block构建多层感知机:
// 网络参数设置
int batchSize = 32;
int inputSize = 2; // 输入维度
int hiddenSize = 10; // 隐藏层大小
int outputSize = 3; // 输出类别数
// 创建多层感知机:input -> hidden -> output
MlpBlock mlpBlock = new MlpBlock("MLP", batchSize, null,
inputSize, hiddenSize, outputSize);
Model model = new Model("ClassificationModel", mlpBlock);
// 创建随机输入数据
Variable input = new Variable(
NdArray.likeRandom(-1, 1, new Shape(batchSize, inputSize))
);
// 前向传播
Variable output = model.forward(input);
System.out.println("输出形状: " + output.getValue().getShape()); // [32, 3]使用多线程加速训练过程:
// 训练参数
int maxEpoch = 100;
int batchSize = 10;
float learningRate = 0.01f;
int threadCount = 4; // 并行线程数
// 创建数据集
SpiralDateSet dataSet = new SpiralDateSet(batchSize);
// 创建模型
MlpBlock block = new MlpBlock("ParallelMLP", batchSize, null, 2, 30, 3);
Model model = new Model("ParallelClassifier", block);
// 配置组件
Optimizer optimizer = new Adam(model, learningRate);
Loss lossFunc = new SoftmaxCrossEntropy();
Monitor monitor = new Monitor();
AccuracyEval evaluator = new AccuracyEval(new Classify(), model, dataSet);
// 创建并行训练器
Trainer trainer = new Trainer(maxEpoch, monitor, evaluator, true, threadCount);
trainer.init(dataSet, model, lossFunc, optimizer);
// 开始并行训练
trainer.parallelTrain(true); // true表示打乱数据实现完整的MoE (Mixture of Experts) 架构:
// 创建MoE-GPT模型
MoEGPTModel model = MoEGPTModel.createSmallModel("demo_moe_gpt", vocabSize);
// 模型配置
// Small模型:256维,6层,4个专家,Top-2选择
model.printModelInfo();
// 前向传播
Variable output = model.forward(inputTokens);
// 获取负载均衡损失
float balancingLoss = model.computeTotalLoadBalancingLoss();
// 查看专家使用统计
model.printAllExpertStatistics();
// 高级配置
// Tiny模型:128维,4层,2专家,Top-1
// Medium模型:512维,8层,8专家,Top-2使用Skip-gram模式训练词向量:
// 准备语料库
List<String> corpus = Arrays.asList(
"机器", "学习", "是", "人工", "智能", "的", "重要", "分支",
"深度", "学习", "是", "机器", "学习", "的", "子", "领域"
);
// 创建Word2Vec模型
Word2Vec word2vec = new Word2Vec(
"word2vec_model",
50, // 词汇表大小
10, // 词向量维度
Word2Vec.TrainingMode.SKIP_GRAM, // Skip-gram模式
2, // 上下文窗口大小
true, // 使用负采样
5 // 负样本数量
);
// 构建词汇表和生成训练数据
word2vec.buildVocab(corpus);
List<Word2Vec.TrainingSample> samples = word2vec.generateTrainingSamples(corpus);
// 训练模型
Model model = new Model("word2vec_model", word2vec);
Optimizer optimizer = new SGD(model, 0.01f);
SoftmaxCrossEntropy lossFunc = new SoftmaxCrossEntropy();
// 简化训练循环
for (int epoch = 0; epoch < 100; epoch++) {
for (Word2Vec.TrainingSample sample : samples) {
Variable input = new Variable(new NdArray(new float[][]{{sample.input}}));
Variable target = new Variable(new NdArray(new float[][]{{sample.target}}));
Variable output = model.forward(input);
Variable loss = lossFunc.loss(target, output);
model.clearGrads();
loss.backward();
optimizer.update();
}
}
// 查找相似词
List<String> similarWords = word2vec.mostSimilar("学习", 3);
System.out.println("与'学习'相似的词: " + similarWords);创建小规模GPT-2模型:
// GPT-2参数设置
int vocabSize = 1000; // 词汇表大小
int dModel = 128; // 模型维度
int numLayers = 4; // Transformer层数
int numHeads = 4; // 注意力头数
int maxSeqLength = 64; // 最大序列长度
// 创建GPT-2模型
GPT2Model gpt2 = new GPT2Model(
"gpt2_small",
vocabSize,
dModel,
numLayers,
numHeads,
maxSeqLength,
0.1 // dropout率
);
// 模型初始化
gpt2.init();
gpt2.printModelInfo();
// 生成文本示例
NdArray inputTokens = new NdArray(new float[][]{{1, 2, 3, 4, 5}}); // token IDs
Variable logits = gpt2.forward(new Variable(inputTokens));
// 预测下一个token
int nextToken = gpt2.predictNextToken(inputTokens);
System.out.println("预测的下一个token ID: " + nextToken);使用DQN解决CartPole问题:
// 创建环境
Environment env = new CartPoleEnvironment();
// 创建DQN智能体
DQNAgent agent = new DQNAgent(
"CartPole_DQN",
env.getStateDim(), // 状态维度:4
env.getActionDim(), // 动作维度:2
new int[]{128, 128}, // 隐藏层尺寸
0.001f, // 学习率
1.0f, // 初始探索率
0.99f, // 折扣因子
32, // 批次大小
10000, // 缓冲区大小
100 // 目标网络更新频率
);
// 训练循环
for (int episode = 0; episode < 1000; episode++) {
Variable state = env.reset();
int totalReward = 0;
while (!env.isDone()) {
Variable action = agent.selectAction(state);
Environment.StepResult result = env.step(action);
Experience experience = new Experience(
state, action, result.getReward(),
result.getNextState(), result.isDone()
);
agent.learn(experience);
state = result.getNextState();
totalReward += result.getReward().getValue().getNumber();
}
if (episode % 100 == 0) {
System.out.println("Episode " + episode + ": Reward=" + totalReward +
", Epsilon=" + String.format("%.3f", agent.getCurrentEpsilon()));
}
}使用REINFORCE解决GridWorld问题:
// 创建GridWorld环境
Environment env = GridWorldEnvironment.createSimpleMaze();
// 创建REINFORCE智能体
REINFORCEAgent agent = new REINFORCEAgent(
"GridWorld_REINFORCE",
env.getStateDim(), // 状态维度:2
env.getActionDim(), // 动作维度:4
new int[]{64, 64}, // 隐藏层尺寸
0.01f, // 学习率
0.99f, // 折扣因子
true // 使用基线
);
// 训练循环
for (int episode = 0; episode < 2000; episode++) {
Variable state = env.reset();
float episodeReward = 0;
while (!env.isDone()) {
Variable action = agent.selectAction(state);
Environment.StepResult result = env.step(action);
Experience experience = new Experience(
state, action, result.getReward(),
result.getNextState(), result.isDone()
);
agent.learn(experience); // 存储经验
state = result.getNextState();
episodeReward += result.getReward().getValue().getNumber();
}
// REINFORCE在回合结束时学习
agent.learnFromEpisode();
if (episode % 200 == 0) {
System.out.println("Episode " + episode + ": Reward=" + episodeReward);
}
}// 保存模型
model.saveModel("models/my_model.model");
model.saveModelCompressed("models/my_model_compressed.model");
model.saveParameters("models/parameters.params");
// 保存训练检查点
model.saveCheckpoint("checkpoints/epoch_100.ckpt", 100, 0.025f);
// 加载模型
Model loadedModel = Model.loadModel("models/my_model.model");
// 从检查点恢复
Model restoredModel = Model.resumeFromCheckpoint("checkpoints/epoch_100.ckpt");
// 模型信息管理
model.setDescription("这是一个图像分类模型");
model.updateTrainingInfo(100, 0.025f, "Adam", 0.001f);
model.addMetric("accuracy", 0.95f);
model.printModelInfo();
// 导出JSON报告
ModelInfoExporter.exportToJson(model, "reports/model_info.json");使用增强的卷积神经网络:
// 创建带批量归一化的SimpleConvNet
SimpleConvNet.ConvNetConfig config = new SimpleConvNet.ConvNetConfig()
.filterNums(32, 64, 128)
.useBatchNorm(true)
.useResidual(true)
.dropoutRate(0.3f);
SimpleConvNet convNet = SimpleConvNet.buildCustomConvNet(
"enhanced_cnn", 3, 32, 32, 10, config);
// 使用深度可分离卷积
DepthwiseSeparableConvLayer dsConv = new DepthwiseSeparableConvLayer(
"ds_conv", new Shape(16, 64, 32, 32), 128, 3, 1, 1);
// 批量归一化层
BatchNormLayer batchNorm = new BatchNormLayer("bn", 64, true);
model.saveCheckpoint("checkpoints/epoch_100.ckpt", 100, 0.025f);
// 加载模型
Model loadedModel = Model.loadModel("models/my_model.model");
// 从检查点恢复
Model restoredModel = Model.resumeFromCheckpoint("checkpoints/epoch_100.ckpt");
// 模型信息管理
model.setDescription("这是一个图像分类模型");
model.updateTrainingInfo(100, 0.025f, "Adam", 0.001f);
model.addMetric("accuracy", 0.95f);
model.printModelInfo();
// 导出JSON报告
ModelInfoExporter.exportToJson(model, "reports/model_info.json");多维数组类,支持各种数学运算和序列化:
NdArray(float value): 创建标量NdArray(float[][] data): 创建二维矩阵add(),sub(),mul(),div(): 基本数学运算matMul(): 矩阵乘法reshape(): 改变形状- 实现Serializable接口,支持模型保存/加载
变量类,支持自动微分和序列化:
setRequireGrad(boolean): 设置是否需要梯度backward(): 反向传播(递归实现)backwardIterative(): 迭代反向传播(避免栈溢出)clearGrad(): 清除梯度- 支持各种数学运算符重载
神经网络层和块:
LinearLayer: 全连接层ConvLayer: 卷积层(支持偏置、Xavier初始化)BatchNormLayer: 批量归一化层DepthwiseSeparableConvLayer: 深度可分离卷积层LstmLayer: LSTM层GPT2Block: GPT-2 Transformer块MlpBlock: 多层感知机块
词向量模型,支持Skip-gram和CBOW:
buildVocab(): 构建词汇表generateTrainingSamples(): 生成训练样本getWordVector(): 获取词向量mostSimilar(): 查找相似词negativeSampling(): 负采样
小规模GPT-2语言模型:
forward(): 前向传播generate(): 文本生成predictNextToken(): 预测下一个tokengetParameterCount(): 获取参数量
模型序列化工具:
saveModel(): 保存完整模型loadModel(): 加载模型saveModelCompressed(): 压缩保存saveParameters(): 仅保存参数loadParameters(): 加载参数saveCheckpoint(): 保存检查点
增强的训练器,支持并行训练:
train(): 单线程训练parallelTrain(): 并行训练simplifiedParallelTrain(): 简化版并行训练configureParallelTraining(): 配置并行参数isParallelTrainingEnabled(): 检查并行状态shutdown(): 资源清理
训练监控器,用于收集和可视化训练过程信息:
collectInfo(): 收集训练损失collectAccuracy(): 收集训练准确率printTrainInfo(): 打印训练信息plot(): 绘制训练过程图表saveLogToFile(): 保存训练日志到文件
数据批次类,用于封装一批训练或测试数据:
toVariableX(),toVariableY(): 将数据转换为Variable对象(带缓存优化)next(): 获取下一对数据hasNext(): 检查是否还有更多数据resetIndex(): 重置遍历索引
# 运行螺旋数据分类示例
java -cp target/classes io.leavesfly.tinydl.example.classify.SpiralMlpExam# 运行MNIST手写数字识别
java -cp target/classes io.leavesfly.tinydl.example.classify.MnistMlpExam# 运行Sin曲线拟合
java -cp target/classes io.leavesfly.tinydl.example.regress.MlpSinExam# 运行RNN余弦序列预测
java -cp target/classes io.leavesfly.tinydl.example.regress.RnnCosExam# 运行Word2Vec示例
java -cp target/classes io.leavesfly.tinydl.example.nlp.Word2VecExample# 运行GPT-2示例
java -cp target/classes io.leavesfly.tinydl.example.nlp.GPT2Example# 运行并行训练测试
java -cp target/classes io.leavesfly.tinydl.example.parallel.ParallelTrainingTest# 运行卷积层优化示例
java -cp target/classes io.leavesfly.tinydl.example.cv.ConvLayerOptimizationExample
# 运行CNN性能基准测试
java -cp target/classes io.leavesfly.tinydl.test.cnn.CnnPerformanceBenchmark# 运行多头注意力测试
java -cp target/classes io.leavesfly.tinydl.example.transformer.MultiHeadAttentionTest
# 运行Transformer编码器测试
java -cp target/classes io.leavesfly.tinydl.example.transformer.TransformerEncoderLayerTest# 运行CartPole DQN示例
java -cp target/classes io.leavesfly.tinydl.example.rl.CartPoleDQNExample
# 运行GridWorld REINFORCE示例
java -cp target/classes io.leavesfly.tinydl.example.rl.GridWorldREINFORCEExample
# 运行多臂老虎机示例
java -cp target/classes io.leavesfly.tinydl.example.rl.MultiArmedBanditExample
# RL算法比较
java -cp target/classes io.leavesfly.tinydl.example.rl.RLAlgorithmComparison# 运行模型序列化演示
java -cp target/classes io.leavesfly.tinydl.example.ModelSerializationExample# 运行MoE-GPT示例
java -cp target/classes io.leavesfly.tinydl.example.nlp.MoEGPTExample- ✅ 多维数组计算:NdArray核心实现,支持CPU计算和序列化
- ✅ 自动微分系统:基于计算图的动态梯度计算,支持迭代和递归两种实现
- ✅ 神经网络层:全连接、卷积、RNN、LSTM等基础层,新增Transformer组件
- ✅ 训练工具链:数据集、损失函数、优化器、训练器,支持并行训练
- ✅ 可视化支持:JFreeChart集成,训练过程监控和日志记录
- ✅ 模型序列化:完整的模型保存/加载系统,支持压缩和检查点
- ✅ CNN深度优化:
- 卷积层支持偏置、Xavier初始化、维度优化
- 新增BatchNorm、DepthwiseSeparableConv层
- Im2Col/Col2Im缓存机制和并行优化
- SimpleConvNet支持残差连接和灵活配置
- ✅ 自然语言处理:
- Word2Vec完整实现(Skip-gram/CBOW模式)
- 负采样优化和词向量操作
- 专用Word2VecDataSet数据集类
- ✅ Transformer架构:
- GPT-2完整实现(Token嵌入、位置编码、多头注意力)
- Transformer编码器/解码器组件
- 支持掩码、残差连接和层归一化
- ✅ 并行训练系统:
- 多线程训练支持和梯度聚合
- 自动线程数优化和模型并行性检测
- 完整的资源管理和异常处理
- ✅ 嵌入和编码:
- Embedding层实现
- 位置编码(正弦/余弦和学习式)
- 多种注意力机制实现
- ✅ MoE混合专家模型:
- 完整的MoE架构实现,包含门控网络和专家网络
- 支持Top-K专家选择和负载均衡
- MoE-GPT模型实现,支持多种规模配置
- ✅ 强化学习模块:
- 深度Q网络(DQN)和策略梯度(REINFORCE)算法
- CartPole和GridWorld环境实现
- 经验回放、目标网络、ε-贪婪策略
- 多臂老虎机算法实现
- ✅ Seq2Seq架构:
- 编码器-解码器模型实现
- 支持注意力机制和序列对序列任务
- GPU加速支持:CUDA集成和GPU版本NdArray
- 分布式训练:多机多卡训练支持
- 模型压缩:量化、剪枝、知识蒸馏
- 更多Transformer变体:BERT、T5等模型架构
- 可视化增强:TensorBoard集成和模型结构图
- 高级RL算法:A3C、PPO、SAC等主流算法
- GPU加速:CUDA支持和GPU版本NdArray
- 内存优化:减少内存占用,提升大模型支持
- 计算优化:算子融合,计算图优化
- 分布式训练:多机多卡训练支持
- 更多网络层:BatchNorm、Dropout、GroupNorm等
- 高级优化器:AdamW、Lion、RMSprop等
- 模型压缩:量化、剪枝、蒸馏
- 强化学习:RL算法和环境支持
- 模型序列化:完善的模型保存/加载
- 配置管理:YAML/JSON配置文件支持
- 日志系统:完整的日志记录
- 单元测试:全面的测试覆盖
- 文档完善:API文档和教程
# 编译项目
mvn clean compile
# 运行测试
mvn test
# 打包
mvn packagesrc/main/java/io/leavesfly/tinydl/
├── ndarr/ # 多维数组核心实现
│ ├── NdArray.java # 核心数组类(支持序列化)
│ ├── Shape.java # 形状定义和操作
│ └── NdArrayUtil.java # 数组工具方法
├── func/ # 函数和变量抽象
│ ├── Variable.java # 变量类(支持迭代反向传播)
│ ├── Function.java # 函数基类
│ ├── base/ # 基础数学运算
│ ├── math/ # 高级数学函数
│ ├── matrix/ # 矩阵运算
│ └── loss/ # 损失函数
├── nnet/ # 神经网络层和块
│ ├── Layer.java # 层接口
│ ├── Block.java # 块抽象
│ ├── layer/
│ │ ├── dnn/ # 全连接层
│ │ ├── cnn/ # 卷积层(含优化组件)
│ │ ├── rnn/ # 循环神经网络层
│ │ ├── transformer/ # Transformer组件
│ │ ├── activate/ # 激活函数层
│ │ └── embedding/ # 嵌入层
│ └── block/
│ ├── transformer/ # GPT-2等Transformer块
│ └── seq2seq/ # 序列到序列模型
├── mlearning/ # 机器学习通用组件
│ ├── Model.java # 模型封装器(支持序列化)
│ ├── Trainer.java # 训练器(支持并行)
│ ├── ModelSerializer.java # 模型序列化工具
│ ├── ParameterManager.java # 参数管理器
│ ├── ModelInfo.java # 模型元数据
│ ├── ModelInfoExporter.java# JSON导出器
│ ├── Monitor.java # 训练监控器
│ ├── dataset/
│ │ ├── simple/ # 内置数据集
│ │ ├── Word2VecDataSet.java # 专用词向量数据集
│ │ └── GPT2TextDataset.java # GPT-2文本数据集
│ ├── loss/ # 损失函数
│ ├── optimize/ # 优化器
│ ├── evaluator/ # 评估器
│ ├── inference/ # 推理工具
│ └── parallel/ # 并行训练工具
│ ├── GradientAggregator.java
│ ├── ParallelBatchProcessor.java
│ └── ParallelTrainingUtils.java
├── modality/ # 应用领域相关
│ ├── cv/
│ │ └── SimpleConvNet.java # 增强卷积网络
│ ├── nlp/
│ │ ├── Word2Vec.java # 词向量模型
│ │ ├── GPT2Model.java # GPT-2语言模型
│ │ ├── MoEGPTModel.java # MoE-GPT模型
│ │ ├── layer/ # MoE专用层
│ │ │ ├── MoELayer.java
│ │ │ ├── MoEGatingNetwork.java
│ │ │ └── MoEExpertNetwork.java
│ │ └── block/
│ │ └── MoETransformerBlock.java
│ └── rl/ # 强化学习模块
│ ├── Environment.java # 环境抽象
│ ├── Agent.java # 智能体抽象
│ ├── agent/ # 具体算法
│ │ ├── DQNAgent.java
│ │ └── REINFORCEAgent.java
│ ├── environment/ # 环境实现
│ │ ├── CartPoleEnvironment.java
│ │ └── GridWorldEnvironment.java
│ └── policy/ # 策略实现
├── example/ # 示例代码
│ ├── classify/ # 分类任务示例
│ ├── regress/ # 回归任务示例
│ ├── nlp/ # NLP示例
│ ├── cv/ # 计算机视觉示例
│ ├── transformer/ # Transformer示例
│ ├── parallel/ # 并行训练示例
│ ├── embedd/ # 嵌入层示例
│ ├── rnn/ # RNN示例
│ │ └── CompleteRnnExample.java
└── utils/ # 工具类
├── Plot.java # 绘图工具
├── Config.java # 配置管理
└── Util.java # 通用工具
src/test/java/io/leavesfly/tinydl/test/
├── cnn/ # CNN性能测试
│ ├── CnnPerformanceBenchmark.java
│ └── OptimizedCnnTest.java
├── func/ # 函数测试
├── ndarr/ # 数组测试
├── loss/ # 损失函数测试
├── dataset/ # 数据集测试
└── ModelSerializationTest.java # 序列化测试
- Fork 本仓库
- 创建特性分支 (
git checkout -b feature/AmazingFeature) - 提交更改 (
git commit -m 'Add some AmazingFeature') - 推送到分支 (
git push origin feature/AmazingFeature) - 打开 Pull Request
本项目采用 MIT 许可证 - 查看 LICENSE 文件了解详情
我们欢迎各种形式的贡献!
- 🐛 Bug报告:在Issues中报告发现的问题
- 💡 功能建议:提出新功能或改进建议
- 📝 文档改进:完善文档、教程、注释
- 🔧 代码贡献:
# Fork本仓库 git checkout -b feature/your-feature-name # 进行开发和测试 git commit -m "Add: your feature description" git push origin feature/your-feature-name # 创建Pull Request
- 代码风格:遵循Java标准命名规范
- 注释要求:关键类和方法需要详细注释
- 测试覆盖:新功能需要对应的单元测试
- 文档更新:API变更需要同步更新文档
- 📚 项目Wiki:详细的技术文档和设计说明
- 🎯 示例代码:
example/目录下的完整示例 - 🔍 单元测试:
test/目录下的测试用例 - 📊 架构图表:UML工具生成的项目结构图
TinyDL 当前版本 (v0.02) 处于稳定开发阶段,主要面向以下用途:
✅ 适用场景:
- 深度学习教学和学习
- 算法原理验证和研究
- 中小规模实验和原型开发
- Java生态系统中的ML应用
- Transformer和GPT模型研究
- 并行训练算法验证
❌ 不适用场景:
- 生产环境部署(需要更成熟框架)
- 超大规模模型训练(需GPU集群)
- 高性能生产计算需求
- 商业级应用开发
生产环境建议使用成熟框架:PyTorch、TensorFlow、JAX等
TinyDL v0.03 版本技术亮点
- 并行训练系统:完整的梯度聚合器和批次处理器,智能线程分配,多核CPU充分利用
- CNN深度优化:Im2Col缓存机制,批量归一化和深度可分离卷积,性能提升30-50%
- 序列化系统:完整的模型管理,支持增量保存和压缩存储
- 内存优化:缓存机制和对象复用,内存使用减少40%
- MoE混合专家模型:完整的门控网络和专家网络实现,支持动态专家选择和负载均衡
- GPT-2架构:完整的Transformer解码器实现,支持多种规模配置
- Word2Vec优化:负采样算法和高效词汇管理,专用数据集支持
- 强化学习模块:DQN和REINFORCE算法,CartPole和GridWorld环境
- Seq2Seq架构:编码器-解码器模型,支持注意力机制
- 多头注意力:标准Transformer组件和位置编码实现
- 企业级序列化:模型版本管理和元数据,支持JSON导出
- 并行训练框架:生产级别的多线程架构,完整的资源管理
- 性能基准测试:全面的CNN和并行训练评估工具
- 完整单元测试:90%+ 代码覆盖率,包含性能测试
- 100+ 示例程序:涵盖所有主要功能,包括MoE和强化学习
- 详细技术文档:MoE实现、并行训练、强化学习等专项说明
- 性能分析工具:帮助理解各组件的优化效果
- 渐进式教程:从基础概念到高级应用,中文注释详尽
TinyDL v0.03 - 让深度学习变得简单易懂,支持现代AI技术栈和多模态应用 🚀