
Aadduri/clean code #258

Open
abhinadduri wants to merge 18 commits into main from aadduri/clean_code

Conversation

@abhinadduri
Collaborator

No description provided.

@abhinadduri abhinadduri requested a review from a team as a code owner February 24, 2026 04:20
@gemini-code-assist

Summary of Changes

Hello @abhinadduri, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the model architecture by removing several older or less-utilized models, consolidating the focus on the StateTransitionPerturbationModel. A major new feature is the introduction of a Negative Binomial loss function, enhancing the model's capability to handle count-based gene expression data. Concurrently, various deprecated features have been removed, and the training and prediction pipelines have been updated to support the new loss and provide more precise metric logging. These changes, coupled with an updated optimizer configuration and comprehensive documentation, mark a substantial step towards a cleaner, more robust, and feature-rich codebase.

Highlights

  • Model Consolidation and Removal: The CPA, scVI, scGPT, and OldNeuralOT models, along with their associated configurations, dataset classes, and utility functions, have been removed from the codebase. This streamlines the project by focusing on the StateTransitionPerturbationModel.
  • Negative Binomial (NB) Loss Implementation: A new Negative Binomial (NB) loss function has been introduced for the StateTransitionPerturbationModel, specifically designed for count data. This includes validation checks for output_space compatibility, library size estimation from control cells, and conditional disabling of the gene decoder when NB loss is active.
  • Deprecation of Model Features: Several model features have been deprecated or removed, including residual_decoder, confidence_token, use_batch_token, batch_predictor, mmd_num_chunks, randomize_mmd_chunks, and the combined sinkhorn+energy loss (se). Warnings are now logged if these legacy parameters are encountered.
  • Enhanced Training and Prediction Logic: The training and prediction steps have been refactored to support the new NB loss, including specific handling for count data, library size calculations, and conditional clipping of outputs. Loss logging now uses more granular keys like train/expression_loss and val/embedding_loss.
  • Improved Optimizer Configuration: The StateTransitionPerturbationModel now includes a dedicated configure_optimizers method, offering support for AdamW and flexible cosine learning rate decay schedules.
  • Documentation and Version Update: The README.md has been updated with a detailed 'ST training contract' section, clarifying the interaction between embed_key, output_space, decoder usage, and logged metrics. The project version in pyproject.toml has been incremented to 1.0.0.
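
For reviewers unfamiliar with the NB parameterization, the negative log-likelihood commonly used for count data (for example in scVI-style models, parameterized by a mean `mu` and an inverse-dispersion `theta`) can be sketched as below. This is an illustrative reference only, not the implementation added in this PR:

```python
import torch

def nb_nll(counts: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor,
           eps: float = 1e-8) -> torch.Tensor:
    """Negative Binomial negative log-likelihood with mean `mu` and
    inverse-dispersion `theta`, averaged over all entries."""
    log_theta_mu = torch.log(theta + mu + eps)
    log_lik = (
        theta * (torch.log(theta + eps) - log_theta_mu)   # theta * log(theta / (theta + mu))
        + counts * (torch.log(mu + eps) - log_theta_mu)   # counts * log(mu / (theta + mu))
        + torch.lgamma(counts + theta)                    # combinatorial normalizer
        - torch.lgamma(theta)
        - torch.lgamma(counts + 1.0)
    )
    return -log_lik.mean()
```

Larger `theta` means less overdispersion; as `theta` grows the NB approaches a Poisson with rate `mu`, which is why it is a natural fit for count-based gene expression.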


Changelog
  • .gitignore
    • Added logs/ and debugging/ directories to be ignored.
  • README.md
    • Added a new section detailing the 'ST training contract' for embed_key, output_space, decoder, and metrics.
  • pyproject.toml
    • Updated the project version from 0.10.4 to 1.0.0.
  • src/state/_cli/_tx/_infer.py
    • Implemented checks and validation for nb_loss_enabled in conjunction with output_space.
    • Adjusted counts_expected logic to incorporate nb_loss_enabled.
    • Modified output clipping to apply only to legacy decoder outputs, leaving NB count outputs unclipped.
  • src/state/_cli/_tx/_predict.py
    • Integrated nb_loss_enabled checks and validation for output_space.
    • Ensured exp_counts aligns with is_log1p when nb_loss is active, adding relevant warnings.
    • Removed the import and usage of OldNeuralOTPerturbationModel.
    • Updated pseudobulk processing to use use_count_outputs and conditionally skip clipping based on nb_loss_enabled.
    • Modified pdex_kwargs to dynamically use metrics_is_log1p.
  • src/state/_cli/_tx/_train.py
    • Removed imports and logic specific to json, Path, MixedPrecision, and model-specific configurations for scGPT, CPA, and scVI.
    • Added nb_loss_enabled validation and logic to enforce store_raw_basal=True under specific conditions.
    • Introduced checkpoint_monitor_metric for more flexible checkpoint monitoring.
    • Updated decoder creation logic to disable it when nb_loss is enabled.
    • Removed residual_decoder from LatentToGeneDecoder initialization parameters.
  • src/state/configs/model/cpa.yaml
    • Removed the CPA model configuration file.
  • src/state/configs/model/old_neuralot.yaml
    • Removed the old_neuralot model configuration file.
  • src/state/configs/model/pertsets.yaml
    • Removed confidence_token, residual_decoder, and use_batch_token parameters.
  • src/state/configs/model/pseudobulk.yaml
    • Removed the residual_decoder parameter.
  • src/state/configs/model/scgpt-chemical.yaml
    • Removed the scgpt-chemical model configuration file.
  • src/state/configs/model/scgpt-genetic.yaml
    • Removed the scgpt-genetic model configuration file.
  • src/state/configs/model/scvi.yaml
    • Removed the scvi model configuration file.
  • src/state/configs/model/state.yaml
    • Removed confidence_token, residual_decoder, use_batch_token, mmd_num_chunks, and randomize_mmd_chunks parameters.
    • Added the nb_loss parameter.
  • src/state/configs/model/state_lg.yaml
    • Removed the residual_decoder parameter.
  • src/state/configs/model/state_sm.yaml
    • Removed residual_decoder, mmd_num_chunks, and randomize_mmd_chunks parameters.
  • src/state/configs/model/tahoe_best.yaml
    • Removed the residual_decoder parameter.
  • src/state/configs/model/tahoe_llama_212693232.yaml
    • Removed the residual_decoder parameter.
  • src/state/configs/model/tahoe_llama_62089464.yaml
    • Removed the residual_decoder parameter.
  • src/state/configs/training/cpa.yaml
    • Removed the CPA training configuration file.
  • src/state/configs/training/default.yaml
    • Added new optimizer configuration parameters: optimizer, use_cosine_decay, max_lr, lr_decay_steps, and max_lr_fraction.
  • src/state/configs/training/scgpt.yaml
    • Removed the scgpt training configuration file.
  • src/state/configs/training/scvi.yaml
    • Removed the scvi training configuration file.
  • src/state/tx/callbacks/batch_speed_monitor.py
    • Removed detailed logging of min, max, average, coefficient of variation, and max/min ratio for batch times.
  • src/state/tx/callbacks/model_flops_utilization.py
    • Removed logging of cell_sets_per_sec.
  • src/state/tx/data/dataset/__init__.py
    • Removed the import of scGPTPerturbationDataset.
  • src/state/tx/data/dataset/scgpt_perturbation_dataset.py
    • Removed the scgpt_perturbation_dataset file.
  • src/state/tx/models/__init__.py
    • Removed the import of OldNeuralOTPerturbationModel.
  • src/state/tx/models/base.py
    • Removed residual_decoder from LatentToGeneDecoder initialization and logic.
    • Introduced _sanitize_decoder_cfg to handle deprecated residual_decoder parameter.
    • Refactored on_load_checkpoint to simplify decoder configuration and enforce decoder_cfg presence.
    • Added new methods (_main_loss_is_expression, _train_main_loss_key, _val_main_loss_key, etc.) for standardized loss logging.
    • Updated training_step, validation_step, and test_step to use the new standardized loss keys.
  • src/state/tx/models/context_mean.py
    • Updated training_step to use the new _train_main_loss_key for logging.
  • src/state/tx/models/cpa/__init__.py
    • Removed the __init__.py file for CPA.
  • src/state/tx/models/cpa/_base_modules.py
    • Removed the _base_modules file.
  • src/state/tx/models/cpa/_callbacks.py
    • Removed the _callbacks file.
  • src/state/tx/models/cpa/_dists.py
    • Removed the _dists file.
  • src/state/tx/models/cpa/_model.py
    • Removed the _model file.
  • src/state/tx/models/cpa/_module.py
    • Removed the _module file.
  • src/state/tx/models/cpa/_task.py
    • Removed the _task file.
  • src/state/tx/models/decoder_only.py
    • Updated docstring to reflect changes in model lineage.
    • Modified training_step, validation_step, and on_validation_batch_end to use new loss keys.
    • Removed decoder_test_loss logging from test_step.
  • src/state/tx/models/decoders.py
    • Removed the decoders file.
  • src/state/tx/models/embed_sum.py
    • Updated docstring to reflect current model capabilities.
  • src/state/tx/models/old_neural_ot.py
    • Removed the old_neural_ot model file.
  • src/state/tx/models/perturb_mean.py
    • Updated training_step to use the new _train_main_loss_key for logging.
  • src/state/tx/models/pseudobulk.py
    • Removed FinetuneVCICountsDecoder import and related logic.
    • Updated _maybe_concat_batch logic.
    • Modified training_step, validation_step, test_step, and predict_step to use new loss keys and simplified batch concatenation logic.
  • src/state/tx/models/scgpt/__init__.py
    • Removed the __init__.py file for scgpt.
  • src/state/tx/models/scgpt/dsbn.py
    • Removed the dsbn file.
  • src/state/tx/models/scgpt/gene_tokenizer.py
    • Removed the gene_tokenizer file.
  • src/state/tx/models/scgpt/generation_model.py
    • Removed the generation_model file.
  • src/state/tx/models/scgpt/grad_reverse.py
    • Removed the grad_reverse file.
  • src/state/tx/models/scgpt/lightning_model.py
    • Removed the lightning_model file.
  • src/state/tx/models/scgpt/loss.py
    • Removed the loss file.
  • src/state/tx/models/scgpt/model.py
    • Removed the model file.
  • src/state/tx/models/scgpt/utils.py
    • Removed the utils file.
  • src/state/tx/models/scvi/__init__.py
    • Removed the __init__.py file for scvi.
  • src/state/tx/models/scvi/_base_modules.py
    • Removed the _base_modules file.
  • src/state/tx/models/scvi/_callbacks.py
    • Removed the _callbacks file.
  • src/state/tx/models/scvi/_dists.py
    • Removed the _dists file.
  • src/state/tx/models/scvi/_model.py
    • Removed the _model file.
  • src/state/tx/models/scvi/_module.py
    • Removed the _module file.
  • src/state/tx/models/scvi/_task.py
    • Removed the _task file.
  • src/state/tx/models/state_transition.py
    • Removed anndata import, CombinedLoss, and ConfidenceToken classes.
    • Simplified __init__ by removing the regularization, mmd_num_chunks, and randomize_mmd_chunks parameters and adding an nb_loss parameter.
    • Added new helper methods for handling Negative Binomial loss, including count space conversion and library size calculation.
    • Refactored forward, training_step, validation_step, test_step, and predict_step to integrate NB loss logic and remove deprecated features.
    • Implemented a dedicated configure_optimizers method supporting AdamW and cosine learning rate decay.
  • src/state/tx/utils/__init__.py
    • Updated get_checkpoint_callbacks to accept and utilize a monitor_metric parameter.
    • Removed references to OldNeuralOTPerturbationModel, CPAPerturbationModel, SCVIPerturbationModel, and scGPTForPerturbation from get_lightning_module.
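
The new training knobs in default.yaml (optimizer, use_cosine_decay, max_lr, lr_decay_steps, max_lr_fraction) suggest a schedule along the following lines. This is a sketch of how such a configure_optimizers might be wired up, not the PR's actual code; the helper name and default values are hypothetical:

```python
import math
import torch

def build_optimizer_and_scheduler(params, optimizer: str = "adamw",
                                  max_lr: float = 1e-4,
                                  use_cosine_decay: bool = True,
                                  lr_decay_steps: int = 10_000,
                                  max_lr_fraction: float = 0.1):
    """Sketch: AdamW with cosine decay from max_lr down to
    max_lr * max_lr_fraction over lr_decay_steps steps."""
    if optimizer.lower() != "adamw":
        raise ValueError(f"unsupported optimizer: {optimizer}")
    opt = torch.optim.AdamW(params, lr=max_lr)
    if not use_cosine_decay:
        return opt, None

    def cosine(step: int) -> float:
        # Multiplier applied to max_lr: starts at 1.0, decays to max_lr_fraction.
        progress = min(step / lr_decay_steps, 1.0)
        floor = max_lr_fraction
        return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=cosine)
    return opt, sched
```

Clamping progress at 1.0 keeps the learning rate at the floor once lr_decay_steps is reached, rather than cycling back up.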


@gemini-code-assist (bot) left a comment


Code Review

This pull request performs a significant cleanup of the codebase, removing several legacy model implementations (CPA, scVI, scGPT, OldNeuralOT) and focusing on the StateTransitionPerturbationModel. Key enhancements include the introduction of Negative Binomial (NB) loss support for modeling count data, the addition of cosine learning rate decay, and more flexible optimizer configuration. The refactoring of the decoder logic and checkpoint loading in the base class improves maintainability. However, there is a critical issue in the residual addition logic within StateTransitionPerturbationModel that could lead to shape mismatches during training or inference when the input and output dimensions differ.

Comment on lines 432 to 435
if self.predict_residual and self.output_space == "all":
    # Project control_cells to hidden_dim space to match res_pred
    # control_cells_hidden = self.project_to_hidden(control_cells)
    # treat the actual prediction as a residual sum to basal
    out_pred = self.project_out(res_pred) + basal


Severity: high

There is a potential shape mismatch here when predict_residual is enabled and output_space is set to 'all'. In this case, self.project_out(res_pred) will have a feature dimension of output_dim (which is gene_dim for the full transcriptome), while basal has a feature dimension of input_dim (which could be hvg_dim or an embedding dimension). If input_dim != output_dim, the addition will fail with a runtime error. Consider checking if the dimensions match before performing the addition, or projecting basal to the output space.
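
A defensive version of the residual add, along the lines the review suggests, might look like the sketch below; `project_basal` is a hypothetical projection hook, not an attribute of the actual model:

```python
import torch

def safe_residual_add(out_proj: torch.Tensor, basal: torch.Tensor,
                      project_basal=None) -> torch.Tensor:
    """Add `basal` as a residual only when feature dims match; otherwise
    project it into the output space via the (hypothetical) `project_basal`
    hook, or fail with a descriptive error instead of a cryptic one."""
    if out_proj.shape[-1] == basal.shape[-1]:
        return out_proj + basal
    if project_basal is not None:
        return out_proj + project_basal(basal)
    raise RuntimeError(
        f"residual add dim mismatch: output {out_proj.shape[-1]} vs basal {basal.shape[-1]}"
    )
```

This turns a silent shape-dependent failure into either a correct projection or an immediate, explicit error at the call site.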
