| \n", " | date | \n", "year | \n", "month | \n", "dayofyear | \n", "t | \n", "influencer_spend | \n", "shipping_threshold | \n", "intercept | \n", "trend | \n", "cs | \n", "cc | \n", "seasonality | \n", "epsilon | \n", "y | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "2019-04-01 | \n", "2019 | \n", "4 | \n", "91 | \n", "0 | \n", "0.918883 | \n", "25.0 | \n", "2.0 | \n", "0.778279 | \n", "-0.012893 | \n", "0.006446 | \n", "-0.003223 | \n", "-0.118826 | \n", "2.561363 | \n", "
| 1 | \n", "2019-04-08 | \n", "2019 | \n", "4 | \n", "98 | \n", "1 | \n", "0.230898 | \n", "25.0 | \n", "2.0 | \n", "0.795664 | \n", "0.225812 | \n", "-0.113642 | \n", "0.056085 | \n", "0.064977 | \n", "2.264874 | \n", "
| 2 | \n", "2019-04-15 | \n", "2019 | \n", "4 | \n", "105 | \n", "2 | \n", "0.254486 | \n", "25.0 | \n", "2.0 | \n", "0.812559 | \n", "0.451500 | \n", "-0.232087 | \n", "0.109706 | \n", "-0.020269 | \n", "1.998208 | \n", "
| 3 | \n", "2019-04-22 | \n", "2019 | \n", "4 | \n", "112 | \n", "3 | \n", "0.035995 | \n", "25.0 | \n", "2.0 | \n", "0.828993 | \n", "0.651162 | \n", "-0.347175 | \n", "0.151993 | \n", "0.400209 | \n", "1.701116 | \n", "
| 4 | \n", "2019-04-29 | \n", "2019 | \n", "4 | \n", "119 | \n", "4 | \n", "0.336013 | \n", "25.0 | \n", "2.0 | \n", "0.844997 | \n", "0.813290 | \n", "-0.457242 | \n", "0.178024 | \n", "0.057609 | \n", "2.003646 | \n", "
Sampler Progress
\n", "Total Chains: 4
\n", "Active Chains: 0
\n", "\n", " Finished Chains:\n", " 4\n", "
\n", "Sampling for 14 seconds
\n", "\n", " Estimated Time to Completion:\n", " now\n", "
\n", "\n", " \n", "| Progress | \n", "Draws | \n", "Divergences | \n", "Step Size | \n", "Gradients/Draw | \n", "
|---|---|---|---|---|
| \n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "63 | \n", "
| \n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "31 | \n", "
| \n", " \n", " | \n", "2000 | \n", "0 | \n", "0.18 | \n", "79 | \n", "
| \n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "31 | \n", "
/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n",
"\"ipywidgets\" for Jupyter support\n",
" warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
"\n"
],
"text/plain": [
"/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n",
"\"ipywidgets\" for Jupyter support\n",
" warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling: [y]\n"
]
},
{
"data": {
"text/html": [
"/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n",
"\"ipywidgets\" for Jupyter support\n",
" warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
"\n"
],
"text/plain": [
"/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n",
"\"ipywidgets\" for Jupyter support\n",
" warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mmm = MMM(\n",
" date_column=\"date\",\n",
" target_column=\"y\",\n",
" adstock=GeometricAdstock(l_max=8),\n",
" saturation=LogisticSaturation(),\n",
" channel_columns=[\"influencer_spend\"],\n",
" control_columns=[\"t\", \"shipping_threshold\"],\n",
" yearly_seasonality=2,\n",
")\n",
"\n",
"x_train = df.drop(columns=[\"y\"])\n",
"y_train = df[\"y\"]\n",
"\n",
"mmm.fit(\n",
" X=x_train,\n",
" y=y_train,\n",
" nuts_sampler=\"nutpie\",\n",
" nuts_sampler_kwargs={\n",
" \"backend\": \"jax\",\n",
" \"gradient_backend\": \"jax\",\n",
" },\n",
")\n",
"mmm.sample_posterior_predictive(x_train, extend_idata=True);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sensitivity analysis and marginal effects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A multiplicative sweep on influencer spend"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n",
"Sampling: [y]\n"
]
},
{
"data": {
"text/html": [
"<xarray.Dataset> Size: 98MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 49MB -0.6762 ... 0.363\n",
" marginal_effects (chain, draw, date, sweep) float64 49MB 1.464 ... 0.5909\n",
"Attributes:\n",
" sweep_type: multiplicative\n",
" var_names: ['influencer_spend']<xarray.Dataset> Size: 33MB\n",
"Dimensions: (chain: 4, draw: 1000, control: 2,\n",
" fourier_mode: 4, date: 127,\n",
" channel: 1)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 ... 998 999\n",
" * control (control) object 16B 'shipping_t...\n",
" * fourier_mode (fourier_mode) object 32B 'sin_1...\n",
" * date (date) datetime64[ns] 1kB 2019-0...\n",
" * channel (channel) <U16 64B 'influencer_s...\n",
"Data variables: (12/16)\n",
" intercept_contribution (chain, draw) float64 32kB 0.488...\n",
" adstock_alpha_logodds__ (chain, draw) float64 32kB -0.43...\n",
" saturation_lam_log__ (chain, draw) float64 32kB 1.358...\n",
" saturation_beta_log__ (chain, draw) float64 32kB -0.26...\n",
" gamma_control (chain, draw, control) float64 64kB ...\n",
" gamma_fourier (chain, draw, fourier_mode) float64 128kB ...\n",
" ... ...\n",
" y_sigma (chain, draw) float64 32kB 0.071...\n",
" channel_contribution (chain, draw, date, channel) float64 4MB ...\n",
" total_media_contribution_original_scale (chain, draw) float64 32kB 177.3...\n",
" control_contribution (chain, draw, date, control) float64 8MB ...\n",
" fourier_contribution (chain, draw, date, fourier_mode) float64 16MB ...\n",
" yearly_seasonality_contribution (chain, draw, date) float64 4MB ...\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.246258+00:00\n",
" arviz_version: 0.22.0\n",
" inference_library: nutpie\n",
" inference_library_version: 0.15.2\n",
" sampling_time: 14.934990882873535\n",
" tuning_steps: 1000\n",
" pymc_marketing_version: 0.15.1<xarray.Dataset> Size: 336kB\n",
"Dimensions: (chain: 4, draw: 1000)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999\n",
"Data variables:\n",
" depth (chain, draw) uint64 32kB 4 6 4 6 5 4 ... 5 4 5 4 5 5\n",
" maxdepth_reached (chain, draw) bool 4kB False False ... False False\n",
" index_in_trajectory (chain, draw) int64 32kB 4 -27 -11 -9 ... 14 -12 16 9\n",
" logp (chain, draw) float64 32kB 140.5 140.7 ... 138.8 140.4\n",
" energy (chain, draw) float64 32kB -135.1 -136.4 ... -133.7\n",
" diverging (chain, draw) bool 4kB False False ... False False\n",
" energy_error (chain, draw) float64 32kB -0.1269 -0.3253 ... 0.00532\n",
" step_size (chain, draw) float64 32kB 0.189 0.189 ... 0.1878\n",
" step_size_bar (chain, draw) float64 32kB 0.189 0.189 ... 0.1878\n",
" mean_tree_accept (chain, draw) float64 32kB 0.9691 0.9881 ... 0.982\n",
" mean_tree_accept_sym (chain, draw) float64 32kB 0.8554 0.8769 ... 0.985\n",
" n_steps (chain, draw) uint64 32kB 31 95 15 63 ... 63 15 31 47\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.237387+00:00\n",
" arviz_version: 0.22.0<xarray.Dataset> Size: 2kB\n",
"Dimensions: (date: 127)\n",
"Coordinates:\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 2019-04-08 ... 2021-08-30\n",
"Data variables:\n",
" y (date) float64 1kB 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.861715+00:00\n",
" arviz_version: 0.22.0\n",
" inference_library: pymc\n",
" inference_library_version: 5.25.1<xarray.Dataset> Size: 6kB\n",
"Dimensions: (channel: 1, date: 127, control: 2)\n",
"Coordinates:\n",
" * channel (channel) <U16 64B 'influencer_spend'\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * control (control) <U18 144B 'shipping_threshold' 't'\n",
"Data variables:\n",
" channel_scale (channel) float64 8B 0.9919\n",
" target_scale float64 8B 3.981\n",
" channel_data (date, channel) float64 1kB 0.9189 0.2309 ... 0.2797 0.2041\n",
" target_data (date) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0\n",
" control_data (date, control) float64 2kB 25.0 0.0 25.0 ... 20.0 126.0\n",
" dayofyear (date) int32 508B 91 98 105 112 119 ... 214 221 228 235 242\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.863703+00:00\n",
" arviz_version: 0.22.0\n",
" inference_library: pymc\n",
" inference_library_version: 5.25.1<xarray.Dataset> Size: 14kB\n",
"Dimensions: (index: 127)\n",
"Coordinates:\n",
" * index (index) int64 1kB 0 1 2 3 4 5 ... 122 123 124 125 126\n",
"Data variables: (12/14)\n",
" date (index) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" year (index) int32 508B 2019 2019 2019 ... 2021 2021 2021\n",
" month (index) int32 508B 4 4 4 4 4 5 5 5 5 ... 7 7 7 8 8 8 8 8\n",
" dayofyear (index) int32 508B 91 98 105 112 119 ... 221 228 235 242\n",
" t (index) int64 1kB 0 1 2 3 4 5 ... 122 123 124 125 126\n",
" influencer_spend (index) float64 1kB 0.9189 0.2309 ... 0.2797 0.2041\n",
" ... ...\n",
" trend (index) float64 1kB 0.7783 0.7957 0.8126 ... 1.779 1.783\n",
" cs (index) float64 1kB -0.01289 0.2258 ... -0.9747 -0.8932\n",
" cc (index) float64 1kB 0.006446 -0.1136 ... -0.623 -0.5246\n",
" seasonality (index) float64 1kB -0.003223 0.05608 ... -0.7089\n",
" epsilon (index) float64 1kB -0.1188 0.06498 ... -0.3317 -0.05244\n",
" y (index) float64 1kB 2.561 2.265 1.998 ... 2.734 2.607<xarray.Dataset> Size: 4MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 2019-04-08 ... 2021-08-30\n",
"Data variables:\n",
" y (chain, draw, date) float64 4MB 0.6702 0.5258 ... 0.8099 0.6873\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.859848+00:00\n",
" arviz_version: 0.22.0\n",
" inference_library: pymc\n",
" inference_library_version: 5.25.1<xarray.Dataset> Size: 98MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 49MB -0.6762 ... 0.363\n",
" marginal_effects (chain, draw, date, sweep) float64 49MB 1.464 ... 0.5909\n",
"Attributes:\n",
" sweep_type: multiplicative\n",
" var_names: ['influencer_spend']<xarray.Dataset> Size: 488kB\n",
"Dimensions: (chain: 4, draw: 1000, control: 2, fourier_mode: 4)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 ... 995 996 997 998 999\n",
" * control (control) object 16B 'shipping_threshold' 't'\n",
" * fourier_mode (fourier_mode) object 32B 'sin_1' ... 'cos_2'\n",
"Data variables:\n",
" intercept_contribution (chain, draw) float64 32kB 0.2694 0.2694 ... 0.3854\n",
" adstock_alpha_logodds__ (chain, draw) float64 32kB -1.646 -1.646 ... 0.195\n",
" saturation_lam_log__ (chain, draw) float64 32kB 0.4302 0.4302 ... 1.437\n",
" saturation_beta_log__ (chain, draw) float64 32kB 0.5996 ... -0.08112\n",
" gamma_control (chain, draw, control) float64 64kB -0.5422 ... ...\n",
" gamma_fourier (chain, draw, fourier_mode) float64 128kB 0.3939...\n",
" y_sigma_log__ (chain, draw) float64 32kB 1.165 1.165 ... -2.623\n",
" adstock_alpha (chain, draw) float64 32kB 0.1617 0.1617 ... 0.5486\n",
" saturation_lam (chain, draw) float64 32kB 1.538 1.538 ... 4.208\n",
" saturation_beta (chain, draw) float64 32kB 1.821 1.821 ... 0.9221\n",
" y_sigma (chain, draw) float64 32kB 3.205 3.205 ... 0.07259\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.234662+00:00\n",
" arviz_version: 0.22.0<xarray.Dataset> Size: 336kB\n",
"Dimensions: (chain: 4, draw: 1000)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999\n",
"Data variables:\n",
" depth (chain, draw) uint64 32kB 2 0 2 3 2 3 ... 4 6 4 4 5 4\n",
" maxdepth_reached (chain, draw) bool 4kB False False ... False False\n",
" index_in_trajectory (chain, draw) int64 32kB 2 0 2 1 -3 1 ... -24 8 6 13 9\n",
" logp (chain, draw) float64 32kB -2.411e+03 ... 139.3\n",
" energy (chain, draw) float64 32kB 2.986e+03 ... -135.3\n",
" diverging (chain, draw) bool 4kB False True ... False False\n",
" energy_error (chain, draw) float64 32kB -9.958 0.0 ... 0.2545\n",
" step_size (chain, draw) float64 32kB 1.439 0.2431 ... 0.1878\n",
" step_size_bar (chain, draw) float64 32kB 1.439 0.4998 ... 0.1878\n",
" mean_tree_accept (chain, draw) float64 32kB 1.0 0.0 ... 0.9721 0.8904\n",
" mean_tree_accept_sym (chain, draw) float64 32kB 0.08114 0.0 ... 0.9273\n",
" n_steps (chain, draw) uint64 32kB 3 1 3 7 3 ... 63 15 15 63 15\n",
"Attributes:\n",
" created_at: 2025-08-13T10:19:30.240564+00:00\n",
" arviz_version: 0.22.0<xarray.Dataset> Size: 98MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 49MB -0.4471 ... 0.2292\n",
" marginal_effects (chain, draw, date, sweep) float64 49MB -0.3712 ... -1.281\n",
"Attributes:\n",
" sweep_type: absolute\n",
" var_names: ['influencer_spend']<xarray.Dataset> Size: 98MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 49MB -0.03469 ... 0....\n",
" marginal_effects (chain, draw, date, sweep) float64 49MB 1.17 ... 0.02177\n",
"Attributes:\n",
" sweep_type: additive\n",
" var_names: ['influencer_spend']<xarray.Dataset> Size: 98MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 96B 0.0 0.09091 0.1818 ... 0.9091 1.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 49MB 0.4111 ... 0.428\n",
" marginal_effects (chain, draw, date, sweep) float64 49MB -0.6963 ... -1.128\n",
"Attributes:\n",
" sweep_type: absolute\n",
" var_names: ['shipping_threshold']<xarray.Dataset> Size: 244MB\n",
"Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 30)\n",
"Coordinates:\n",
" * chain (chain) int64 32B 0 1 2 3\n",
" * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n",
" * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n",
" * sweep (sweep) float64 240B 0.0 0.03448 0.06897 ... 0.9655 1.0\n",
"Data variables:\n",
" y (chain, draw, date, sweep) float64 122MB 0.4907 ... 0.3568\n",
" marginal_effects (chain, draw, date, sweep) float64 122MB -4.118 ... -4.627\n",
"Attributes:\n",
" sweep_type: absolute\n",
" var_names: ['shipping_threshold']