diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 2a78362e9a..37ffeaeba8 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -141,6 +141,24 @@
 from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
     DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter,
 )
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm import (
+    DeepSeekR1QwenCausalLM as DeepSeekR1Qwen2CausalLM,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm import (
+    DeepSeekR1QwenCausalLM as DeepSeekR1QwenCausalLM,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm_preprocessor import (
+    DeepSeekR1QwenCausalLMPreprocessor as DeepSeekR1Qwen2CausalLMPreprocessor,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm_preprocessor import (
+    DeepSeekR1QwenCausalLMPreprocessor as DeepSeekR1QwenCausalLMPreprocessor,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import (
+    DeepSeekR1QwenTokenizer as DeepSeekR1Qwen2Tokenizer,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import (
+    DeepSeekR1QwenTokenizer as DeepSeekR1QwenTokenizer,
+)
 from keras_hub.src.models.densenet.densenet_backbone import (
     DenseNetBackbone as DenseNetBackbone,
 )
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index 79b6efa192..59d713cfdb 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -22,6 +22,12 @@
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer as DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import (
+    DeepSeekR1QwenTokenizer as DeepSeekR1Qwen2Tokenizer,
+)
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import (
+    DeepSeekR1QwenTokenizer as DeepSeekR1QwenTokenizer,
+)
 from keras_hub.src.models.distil_bert.distil_bert_tokenizer import (
     DistilBertTokenizer as DistilBertTokenizer,
 )
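Each new symbol above is deliberately re-exported twice, once under a `DeepSeekR1Qwen*` name and once under a `DeepSeekR1Qwen2*` name, apparently mirroring the Qwen/Qwen2 aliasing used elsewhere in KerasHub. A minimal sketch of what that means for users of this branch (both aliases resolve to the same class objects):

```python
import keras_hub

# Both aliases point at the same class objects, so either spelling works.
assert (
    keras_hub.models.DeepSeekR1QwenCausalLM
    is keras_hub.models.DeepSeekR1Qwen2CausalLM
)
assert (
    keras_hub.tokenizers.DeepSeekR1QwenTokenizer
    is keras_hub.tokenizers.DeepSeekR1Qwen2Tokenizer
)
```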
diff --git a/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm.py b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm.py
new file mode 100644
index 0000000000..4be16453db
--- /dev/null
+++ b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm.py
@@ -0,0 +1,300 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm import CausalLM
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_causal_lm_preprocessor import (
+    DeepSeekR1QwenCausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export(
+    [
+        "keras_hub.models.DeepSeekR1QwenCausalLM",
+        "keras_hub.models.DeepSeekR1Qwen2CausalLM",
+    ]
+)
+class DeepSeekR1QwenCausalLM(CausalLM):
+    backbone_cls = QwenBackbone
+    preprocessor_cls = DeepSeekR1QwenCausalLMPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
+        hidden_states = backbone(inputs)
+        outputs = backbone.token_embedding(hidden_states, reverse=True)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+    ):
+        """Forward pass of `DeepSeekR1QwenCausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method caches previous key/value tensors in the multi-head attention
+        layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of key and value.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
+
+        Returns:
+            A (logits, hidden_states, cache) tuple, where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        # Each decoder layer has a cache; we update them separately.
+        updated_cache = []
+        for i in range(self.backbone.num_layers):
+            current_cache = cache[:, i, ...]
+            x, next_cache = self.backbone.transformer_layers[i](
+                x,
+                self_attention_cache=current_cache,
+                self_attention_cache_update_index=cache_update_index,
+            )
+            updated_cache.append(next_cache)
+        cache = ops.stack(updated_cache, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_layers
+        num_key_value_heads = self.backbone.num_key_value_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
+        shape = [
+            batch_size,
+            num_layers,
+            2,
+            max_length,
+            num_key_value_heads,
+            head_dim,
+        ]
+        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        # Seed the cache.
+        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        stop_token_ids=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
+            stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(token_ids)
+        # Compute the lengths of all user-provided token ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user-provided id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+                cache_update_index,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self.sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of stop token locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def score(
+        self,
+        token_ids,
+        padding_mask=None,
+        scoring_mode="logits",
+        layer_intercept_fn=None,
+        target_ids=None,
+    ):
+        """Score a generation represented by the provided token ids.
+
+        Args:
+            token_ids: A [batch_size, num_tokens] tensor containing tokens
+                to score. Typically, this tensor captures the output from a
+                call to `DeepSeekR1QwenCausalLM.generate()`, i.e., tokens for
+                both the input text and the model-generated text.
+            padding_mask: A [batch_size, num_tokens] tensor indicating the
+                tokens that should be preserved during generation. This is an
+                artifact required by the `QwenBackbone` and isn't influential
+                on the computation of this function. If omitted, this function
+                uses `keras.ops.ones()` to create a tensor of the appropriate
+                shape.
+            scoring_mode: The type of scores to return, either "logits" or
+                "loss", both will be per input token.
+            layer_intercept_fn: An optional function for augmenting activations
+                with additional computation, for example, as part of
+                interpretability research. This function will be passed the
+                activations as its first parameter and a numeric index
+                associated with that backbone layer. This index is *not* an
+                index into `self.backbone.layers`. The index -1 accompanies the
+                embeddings returned by calling `self.backbone.token_embedding()`
+                on `token_ids` in the forward direction. All subsequent indexes
+                will be 0-based indices for the activations returned by each of
+                the transformer layers in the backbone. This function must
+                return a [batch_size, num_tokens, hidden_dims] tensor
+                that can be passed as an input to the next layer in the model.
+            target_ids: A [batch_size, num_tokens] tensor containing the
+                predicted tokens against which the loss should be computed. If a
+                span of tokens is provided (sequential truthy values along
+                axis=1 in the tensor), the loss will be computed as the
+                aggregate across those tokens.
+
+        Raises:
+            ValueError: If an unsupported scoring_mode is provided, or if
+                target_ids are not provided when using scoring_mode="loss".
+
+        Returns:
+            The per-token scores as a tensor of size
+            [batch_size, num_tokens, vocab_size] in "logits" mode, or
+            [batch_size, num_tokens] in "loss" mode.
+
+        Example:
+
+        Compute gradients between embeddings and loss scores with TensorFlow:
+        ```python
+        causal_lm = keras_hub.models.DeepSeekR1QwenCausalLM.from_preset(
+            "hf://deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+        )
+        generations = causal_lm.generate(
+            ["This is a", "Where are you"],
+            max_length=30
+        )
+        preprocessed = causal_lm.preprocessor.generate_preprocess(generations)
+        generation_ids = preprocessed["token_ids"]
+        padding_mask = preprocessed["padding_mask"]
+        target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
+
+        embeddings = None
+        with tf.GradientTape(watch_accessed_variables=True) as tape:
+            def layer_intercept_fn(x, i):
+                if i == -1:
+                    nonlocal embeddings, tape
+                    embeddings = x
+                    tape.watch(embeddings)
+                return x
+
+            losses = causal_lm.score(
+                token_ids=generation_ids,
+                padding_mask=padding_mask,
+                scoring_mode="loss",
+                layer_intercept_fn=layer_intercept_fn,
+                target_ids=target_ids,
+            )
+
+        grads = tape.gradient(losses, embeddings)
+        ```
+        """
+        if scoring_mode not in ("logits", "loss"):
+            raise ValueError(
+                "Unsupported scoring_mode. Must be one of 'logits' or 'loss'."
+            )
+
+        if scoring_mode == "loss" and target_ids is None:
+            raise ValueError(
+                "Cannot compute loss without targets. Please provide target "
+                "token ids via the target_ids parameter."
+            )
+
+        batch_shape = ops.shape(token_ids)[:2]
+        assert len(batch_shape) == 2
+
+        if padding_mask is None:
+            padding_mask = ops.ones(shape=batch_shape)
+
+        if layer_intercept_fn is None:
+
+            def default_layer_intercept_fn(x, unused_i):
+                return x
+
+            layer_intercept_fn = default_layer_intercept_fn
+
+        token_embeddings = self.backbone.token_embedding(token_ids)
+        x = layer_intercept_fn(token_embeddings, -1)
+
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            x = transformer_layer(x, decoder_padding_mask=padding_mask)
+            x = layer_intercept_fn(x, i)
+
+        x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+
+        if scoring_mode == "logits":
+            return logits
+
+        per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction="none"
+        )
+        per_token_loss = per_token_loss_fn(target_ids, logits)
+        return per_token_loss
diff --git a/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm_preprocessor.py b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..333507356b
--- /dev/null
+++ b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_causal_lm_preprocessor.py
@@ -0,0 +1,20 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.deepseek_r1.deepseek_r1_qwen_tokenizer import (
+    DeepSeekR1QwenTokenizer,
+)
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+
+
+@keras_hub_export(
+    [
+        "keras_hub.models.DeepSeekR1QwenCausalLMPreprocessor",
+        "keras_hub.models.DeepSeekR1Qwen2CausalLMPreprocessor",
+    ]
+)
+class DeepSeekR1QwenCausalLMPreprocessor(CausalLMPreprocessor):
+    backbone_cls = QwenBackbone
+    tokenizer_cls = DeepSeekR1QwenTokenizer
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
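With the causal LM and preprocessor above in place, end-to-end generation follows the existing Qwen workflow. A minimal sketch, mirroring the `sanity()` helper in the conversion script further down; no KerasHub presets are registered in this change, so the `hf://` path below (the one that script loads from) is an assumption about where the weights live:

```python
import keras_hub

preset = "hf://deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# The preprocessor (tokenization, packing, padding) is attached automatically
# when loading the causal LM from a preset.
causal_lm = keras_hub.models.DeepSeekR1QwenCausalLM.from_preset(preset)
print(causal_lm.generate("What is Keras?", max_length=64))

# The preprocessor can also be used standalone, as in test_tokenizer() below.
tokenizer = keras_hub.tokenizers.DeepSeekR1QwenTokenizer.from_preset(preset)
preprocessor = keras_hub.models.DeepSeekR1QwenCausalLMPreprocessor(tokenizer)
x, y, sample_weight = preprocessor(["What is Keras?"], sequence_length=6)
print(x["token_ids"], x["padding_mask"])
```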
diff --git a/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_tokenizer.py b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_tokenizer.py
new file mode 100644
index 0000000000..5f105660ab
--- /dev/null
+++ b/keras_hub/src/models/deepseek_r1/deepseek_r1_qwen_tokenizer.py
@@ -0,0 +1,54 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.DeepSeekR1QwenTokenizer",
+        "keras_hub.tokenizers.DeepSeekR1Qwen2Tokenizer",
+        "keras_hub.models.DeepSeekR1QwenTokenizer",
+        "keras_hub.models.DeepSeekR1Qwen2Tokenizer",
+    ]
+)
+class DeepSeekR1QwenTokenizer(BytePairTokenizer):
+    """Tokenizer for DeepSeek-R1 distilled Qwen models.
+
+    This tokenizer implements byte-pair encoding (BPE) for DeepSeek-R1
+    distilled Qwen models. The BOS token `"<|begin▁of▁sentence|>"` and the
+    EOS token `"<|end▁of▁sentence|>"` are registered automatically and are
+    not constructor arguments.
+
+    Args:
+        vocabulary: Dictionary mapping tokens to token IDs, or path to
+            vocabulary file.
+        merges: List of BPE merges, or path to merges file.
+    """
+
+    backbone_cls = QwenBackbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # Register the BOS and EOS special tokens.
+        eos_token = "<|end▁of▁sentence|>"
+        self._add_special_token(eos_token, "eos_token")
+        bos_token = "<|begin▁of▁sentence|>"
+        self._add_special_token(bos_token, "bos_token")
+
+        self.end_token_id = 151643
+        self.start_token_id = 151646
+        self.start_token = None
+        self.pad_token_id = 0
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
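A small sketch of the tokenizer on its own, using the standard `BytePairTokenizer` call/`detokenize` round trip; the ids printed at the end are the ones hard-coded in `__init__` above, and the `hf://` path is again the one used by the conversion script rather than a published preset:

```python
import keras_hub

tokenizer = keras_hub.tokenizers.DeepSeekR1QwenTokenizer.from_preset(
    "hf://deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
)
token_ids = tokenizer("What is Keras?")
print(token_ids)
print(tokenizer.detokenize(token_ids))
print(tokenizer.end_token_id, tokenizer.start_token_id)  # 151643, 151646
```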
diff --git a/tools/checkpoint_conversion/convert_deepseek_r1_qwen_distil_checkpoints.py b/tools/checkpoint_conversion/convert_deepseek_r1_qwen_distil_checkpoints.py
new file mode 100644
index 0000000000..11a28c4391
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_deepseek_r1_qwen_distil_checkpoints.py
@@ -0,0 +1,136 @@
+import os
+import traceback
+
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Hide any CUDA devices
+
+import numpy as np
+import torch
+from absl import app
+from absl import flags
+
+device = torch.device("cpu")
+# Force PyTorch to use CPU
+torch.set_default_device(device)
+
+from keras import ops  # noqa: E402
+from transformers import AutoModelForCausalLM  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+
+import keras_hub  # noqa: E402
+
+PRESET_MAP = {
+    "deepseekr1-dist-qwen-1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+    "deepseekr1-dist-qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    "deepseekr1-dist-qwen-14B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+    "deepseekr1-dist-qwen-32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
+)
+
+
+def test_model(keras_hub_model, hf_model, hf_model_tokenizer):
+    # First, test that the number of parameters match
+    keras_hub_params = keras_hub_model.count_params()
+    hf_params = hf_model.num_parameters()
+    print(keras_hub_params, hf_params)
+    # assert keras_hub_params == hf_params
+
+    # Test the outputs of both the models
+    hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to(
+        device
+    )
+    hf_outputs = hf_model(**hf_inputs)
+    hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy()
+
+    keras_hub_inputs = keras_hub_model.preprocessor(
+        ["What is Keras?"], sequence_length=6
+    )[0]
+    keras_hub_inputs = {k: v for k, v in keras_hub_inputs.items()}
+
+    keras_hub_logits = keras_hub_model(keras_hub_inputs)
+    keras_hub_logits = ops.convert_to_numpy(keras_hub_logits)
+
+    # High tolerance since bfloat16 is used as the default dtype for Qwen
+    try:
+        np.testing.assert_allclose(
+            keras_hub_logits, hf_output_logits, atol=1e-2
+        )
+    except AssertionError as err:
+        print("\n")
+        print(traceback.format_exc())
+        print(err.args[0])
+        print("\n")
+
+
+def test_tokenizer(keras_hub_tokenizer, hf_tokenizer):
+    hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt")
+    hf_output = hf_output["input_ids"].detach().cpu().numpy()
+    keras_hub_preprocessor = (
+        keras_hub.models.DeepSeekR1QwenCausalLMPreprocessor(keras_hub_tokenizer)
+    )
+    keras_hub_output = keras_hub_preprocessor(
+        ["What is Keras?"], sequence_length=6
+    )
+    keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"])
+    np.testing.assert_equal(keras_hub_output, hf_output)
+
+
+def main(_):
+    # === Get the preset name ===
+    if FLAGS.preset not in PRESET_MAP.keys():
+        raise ValueError(
+            f"Invalid preset {FLAGS.preset}. Must be one "
+            f"of {','.join(PRESET_MAP.keys())}"
+        )
+    preset = FLAGS.preset
+    hf_preset = PRESET_MAP[preset]
+
+    # === Load the Hugging Face model ===
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        hf_preset, device_map=device
+    )
+    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+    hf_model.eval()
+
+    keras_hub_tokenizer = keras_hub.models.DeepSeekR1Qwen2Tokenizer.from_preset(
+        f"hf://{hf_preset}"
+    )
+    keras_hub_model = keras_hub.models.DeepSeekR1QwenCausalLM.from_preset(
+        f"hf://{hf_preset}"
+    )
+
+    print("\n-> Hugging Face and KerasHub models and tokenizers loaded")
+
+    # === Check that the model and tokenizer outputs match ===
+    test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
+    test_model(keras_hub_model, hf_model, hf_tokenizer)
+    print("\n-> Tests passed!")
+
+    # Save the converted model as a local preset directory.
+    keras_hub_model.save_to_preset(preset)
+
+
+def sanity(_):
+    # Quick end-to-end smoke test on the smallest distilled checkpoint.
+    hf_preset = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    backbone = keras_hub.models.Qwen2Backbone.from_preset(f"hf://{hf_preset}")
+    keras_hub_tokenizer = keras_hub.models.DeepSeekR1Qwen2Tokenizer.from_preset(
+        f"hf://{hf_preset}"
+    )
+    keras_hub_preprocessor = (
+        keras_hub.models.DeepSeekR1QwenCausalLMPreprocessor(keras_hub_tokenizer)
+    )
+    keras_hub_model = keras_hub.models.DeepSeekR1QwenCausalLM(
+        backbone=backbone, preprocessor=keras_hub_preprocessor
+    )
+
+    print(keras_hub_model.generate("What is Keras?", max_length=8))
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    # Use `app.run(sanity)` for a quick local smoke test instead.
+    app.run(main)
diff --git a/tools/checkpoint_conversion/convert_qwen_checkpoints.py b/tools/checkpoint_conversion/convert_qwen_checkpoints.py
index 1cbd355081..4db5d35f6f 100644
--- a/tools/checkpoint_conversion/convert_qwen_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_qwen_checkpoints.py
@@ -24,6 +24,7 @@
     "qwen2.5_7b_en": "Qwen/Qwen2.5-7B",
     "qwen2.5_3b_en": "Qwen/Qwen2.5-3B",
     "qwen2.5_instruct_0.5b_en": "Qwen/Qwen2.5-0.5B-Instruct",
+    "qwen2.5_1.5b_en": "Qwen/Qwen2.5-1.5B",
 }
 
 FLAGS = flags.FLAGS
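The conversion tool above is run once per preset (e.g. `python tools/checkpoint_conversion/convert_deepseek_r1_qwen_distil_checkpoints.py --preset deepseekr1-dist-qwen-1.5B`), and `main()` finishes by calling `save_to_preset(preset)`. A short sketch of loading that locally saved preset back, assuming the script completed for the 1.5B variant and wrote into the working directory:

```python
import keras_hub

# `main()` saves the converted model under the preset name, i.e. a local
# directory named "deepseekr1-dist-qwen-1.5B".
causal_lm = keras_hub.models.DeepSeekR1QwenCausalLM.from_preset(
    "./deepseekr1-dist-qwen-1.5B"
)
print(causal_lm.generate("What is Keras?", max_length=32))
```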