6.4 C
New York

Modélisation de données censurées avec tfprobability


Rien n’est jamais parfait, et les données ne le sont pas non plus. Un sort d' »imperfection » est données manquantes, où certaines caractéristiques sont inobservées pour certains sujets. (Un sujet pour un autre put up.) Un autre est données censurées, où un événement dont on veut mesurer les caractéristiques ne se produit pas dans l’intervalle d’statement. L’exemple de Richard McElreath Repenser les statistiques C’est le second d’adopter des chats dans un refuge pour animaux. Si nous fixons un intervalle et observons les temps d’attente pour les chats qui a fait être adopté, notre estimation finira par être trop optimiste : nous ne prenons pas en compte les chats qui n’ont pas été adoptés pendant cet intervalle et qui auraient donc contribué à des temps d’attente plus longs que l’intervalle complet.

Dans cet article, nous utilisons un exemple un peu moins émotionnel qui peut néanmoins être intéressant, en particulier pour les développeurs de packages R : le temps nécessaire à l’achèvement de R CMD testcollectés auprès du CRAN et fournis par le parsnip paquet comme check_times. Ici, la partie censurée correspond aux chèques erronés pour une raison quelconque, c’est-à-dire pour lesquels le chèque n’a pas abouti.

Pourquoi nous soucions-nous de la partie censurée ? Dans le scénario d’adoption de chat, c’est assez évident : nous voulons être en mesure d’obtenir une estimation réaliste pour tout chat inconnu, pas seulement pour les chats qui se révéleront « chanceux ». Que diriez-vous check_times? Eh bien, si votre soumission fait partie de celles qui se sont trompées, vous vous souciez toujours du temps que vous attendez, donc même si leur pourcentage est faible (< 1%), nous ne voulons pas simplement les exclure. En outre, il est attainable que les échecs aient pris plus de temps s'ils étaient terminés, en raison d'une différence intrinsèque entre les deux groupes. À l'inverse, si les échecs étaient aléatoires, les vérifications plus longues auraient plus de possibilities d'être touchées par une erreur. Donc, ici aussi, l'exclusion des données censurées peut entraîner un biais.

Remark pouvons-nous modéliser les durées de cette partie censurée, où la « vraie durée » est inconnue ? En prenant du recul, remark modéliser les durées en général ? En faisant le moins d’hypothèses attainable, les distribution d’entropie maximale pour les déplacements (dans l’espace ou dans le temps) est l’exponentielle. Ainsi, pour les vérifications qui se sont réellement terminées, les durées sont supposées être distribuées de manière exponentielle.

Pour les autres, tout ce que l’on sait, c’est que dans un monde virtuel où la vérification est terminée, il faudrait au moins aussi longtemps comme la durée donnée. Cette quantité peut être modélisée par la fonction de distribution cumulative complémentaire exponentielle (CCDF). Pourquoi? Une fonction de distribution cumulative (CDF) indique la probabilité qu’une valeur inférieure ou égale à un sure level de référence ait été atteinte ; par exemple, « la probabilité de durées <= 255 est de 0,9". Son complément, 1 - CDF, donne alors la probabilité qu'une valeur dépasse ce level de référence.

Voyons cela en motion.

Les données

Le code suivant fonctionne avec les variations stables actuelles de TensorFlow et TensorFlow Likelihood, qui sont respectivement 1.14 et 0.7. Si vous n’avez pas tfprobability installé, récupérez-le sur Github :

Ce sont les bibliothèques dont nous avons besoin. À partir de TensorFlow 1.14, nous appelons tf$compat$v2$enable_v2_behavior() courir avec une exécution impatiente.

Outre les durées de vérification que nous voulons modéliser, check_times rapporte diverses fonctionnalités du paquet en query, telles que le nombre de paquets importés, le nombre de dépendances, la taille des fichiers de code et de documentation, and so forth. standing La variable indique si la vérification est terminée ou erronée.

df <- check_times %>% choose(-bundle)
glimpse(df)
Observations: 13,626
Variables: 24
$ authors        <int> 1, 1, 1, 1, 5, 3, 2, 1, 4, 6, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,…
$ imports        <dbl> 0, 6, 0, 0, 3, 1, 0, 4, 0, 7, 0, 0, 0, 0, 3, 2, 14, 2, 2, 0…
$ suggests       <dbl> 2, 4, 0, 0, 2, 0, 2, 2, 0, 0, 2, 8, 0, 0, 2, 0, 1, 3, 0, 0,…
$ relies upon        <dbl> 3, 1, 6, 1, 1, 1, 5, 0, 1, 1, 6, 5, 0, 0, 0, 1, 1, 5, 0, 2,…
$ Roxygen        <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0,…
$ gh             <dbl> 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,…
$ rforge         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ descr          <int> 217, 313, 269, 63, 223, 1031, 135, 344, 204, 335, 104, 163,…
$ r_count        <int> 2, 20, 8, 0, 10, 10, 16, 3, 6, 14, 16, 4, 1, 1, 11, 5, 7, 1…
$ r_size         <dbl> 0.029053, 0.046336, 0.078374, 0.000000, 0.019080, 0.032607,…
$ ns_import      <dbl> 3, 15, 6, 0, 4, 5, 0, 4, 2, 10, 5, 6, 1, 0, 2, 2, 1, 11, 0,…
$ ns_export      <dbl> 0, 19, 0, 0, 10, 0, 0, 2, 0, 9, 3, 4, 0, 1, 10, 0, 16, 0, 2…
$ s3_methods     <dbl> 3, 0, 11, 0, 0, 0, 0, 2, 0, 23, 0, 0, 2, 5, 0, 4, 0, 0, 0, …
$ s4_methods     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ doc_count      <int> 0, 3, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,…
$ doc_size       <dbl> 0.000000, 0.019757, 0.038281, 0.000000, 0.007874, 0.000000,…
$ src_count      <int> 0, 0, 0, 0, 0, 0, 0, 2, 0, 5, 3, 0, 0, 0, 0, 0, 0, 54, 0, 0…
$ src_size       <dbl> 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,…
$ data_count     <int> 2, 0, 0, 3, 3, 1, 10, 0, 4, 2, 2, 146, 0, 0, 0, 0, 0, 10, 0…
$ data_size      <dbl> 0.025292, 0.000000, 0.000000, 4.885864, 4.595504, 0.006500,…
$ testthat_count <int> 0, 8, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0,…
$ testthat_size  <dbl> 0.000000, 0.002496, 0.000000, 0.000000, 0.000000, 0.000000,…
$ check_time     <dbl> 49, 101, 292, 21, 103, 46, 78, 91, 47, 196, 200, 169, 45, 2…
$ standing         <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…

Sur ces 13 626 observations, seules 103 sont censurées :

0     1 
103 13523 

Pour une meilleure lisibilité, nous allons travailler avec un sous-ensemble de colonnes. Nous utilisons surv_reg pour nous aider à trouver un sous-ensemble utile et intéressant de prédicteurs :

survreg_fit <-
  surv_reg(dist = "exponential") %>% 
  set_engine("survreg") %>% 
  match(Surv(check_time, standing) ~ ., 
      information = df)
tidy(survreg_fit) 
# A tibble: 23 x 7
   time period             estimate std.error statistic  p.worth conf.low conf.excessive
   <chr>               <dbl>     <dbl>     <dbl>    <dbl>    <dbl>     <dbl>
 1 (Intercept)     3.86      0.0219     176.     0.             NA        NA
 2 authors         0.0139    0.00580      2.40   1.65e- 2       NA        NA
 3 imports         0.0606    0.00290     20.9    7.49e-97       NA        NA
 4 suggests        0.0332    0.00358      9.28   1.73e-20       NA        NA
 5 relies upon         0.118     0.00617     19.1    5.66e-81       NA        NA
 6 Roxygen         0.0702    0.0209       3.36   7.87e- 4       NA        NA
 7 gh              0.00898   0.0217       0.414  6.79e- 1       NA        NA
 8 rforge          0.0232    0.0662       0.351  7.26e- 1       NA        NA
 9 descr           0.000138  0.0000337    4.10   4.18e- 5       NA        NA
10 r_count         0.00209   0.000525     3.98   7.03e- 5       NA        NA
11 r_size          0.481     0.0819       5.87   4.28e- 9       NA        NA
12 ns_import       0.00352   0.000896     3.93   8.48e- 5       NA        NA
13 ns_export      -0.00161   0.000308    -5.24   1.57e- 7       NA        NA
14 s3_methods      0.000449  0.000421     1.06   2.87e- 1       NA        NA
15 s4_methods     -0.00154   0.00206     -0.745  4.56e- 1       NA        NA
16 doc_count       0.0739    0.0117       6.33   2.44e-10       NA        NA
17 doc_size        2.86      0.517        5.54   3.08e- 8       NA        NA
18 src_count       0.0122    0.00127      9.58   9.96e-22       NA        NA
19 src_size       -0.0242    0.0181      -1.34   1.82e- 1       NA        NA
20 data_count      0.0000415 0.000980     0.0423 9.66e- 1       NA        NA
21 data_size       0.0217    0.0135       1.61   1.08e- 1       NA        NA
22 testthat_count -0.000128  0.00127     -0.101  9.20e- 1       NA        NA
23 testthat_size   0.0108    0.0139       0.774  4.39e- 1       NA        NA

Il semble que si nous choisissons imports, relies upon, r_size, doc_size, ns_import et ns_export nous nous retrouvons avec un mélange de prédicteurs (relativement) puissants de différents espaces sémantiques et de différentes échelles.

Avant d’élaguer la trame de données, nous sauvegardons la variable cible. Dans notre configuration de modèle et de formation, il est pratique d’avoir des données censurées et non censurées stockées séparément, donc ici nous créons deux matrices cibles au lieu d’une :

# test instances for failed checks
# _c stands for censored
check_time_c <- df %>%
  filter(standing == 0) %>%
  choose(check_time) %>%
  as.matrix()

# test instances for profitable checks 
check_time_nc <- df %>%
  filter(standing == 1) %>%
  choose(check_time) %>%
  as.matrix()

Nous pouvons maintenant zoomer sur les variables d’intérêt, en configurant une trame de données pour les données censurées et une pour les données non censurées chacune. Tous les prédicteurs sont normalisés pour éviter tout débordement lors de l’échantillonnage. Nous ajoutons une colonne de 1s à utiliser comme interception.

df <- df %>% choose(standing,
                    relies upon,
                    imports,
                    doc_size,
                    r_size,
                    ns_import,
                    ns_export) %>%
  mutate_at(.vars = 2:7, .funs = perform(x) (x - min(x))/(max(x)-min(x))) %>%
  add_column(intercept = rep(1, nrow(df)), .earlier than = 1)

# dataframe of predictors for censored information  
df_c <- df %>% filter(standing == 0) %>% choose(-standing)
# dataframe of predictors for non-censored information 
df_nc <- df %>% filter(standing == 1) %>% choose(-standing)

Voilà pour les préparatifs. Mais bien sûr, nous sommes curieux. Les heures de vérification sont-elles différentes ? Les prédicteurs – ceux que nous avons choisis – semblent-ils différents ?

En comparant quelques centiles significatifs pour les deux courses, nous constatons que les durées des contrôles inachevés sont supérieures à celles des contrôles terminés tout au lengthy, à l’exception du centile 100 %. Il n’est pas surprenant qu’étant donné l’énorme différence de taille d’échantillon, la durée maximale soit plus longue pour les vérifications terminées. Sinon, ne semble-t-il pas que les vérifications de paquets erronées « allaient prendre plus de temps » ?

complété 36 54 79 115 211 1343
pas achevé 42 71 97 143 293 696

Et les prédicteurs ? Nous ne voyons aucune différence pour relies uponle nombre de dépendances de packages (hormis, là encore, le most le plus élevé atteint pour les packages dont la vérification est terminée) :

complété 0 1 1 2 4 12
pas achevé 0 1 1 2 4 7

Mais pour tous les autres, nous observons le même schéma que celui rapporté ci-dessus pour check_time. Le nombre de packages importés est plus élevé pour les données censurées à tous les centiles en plus du most :

complété 0 0 2 4 9 43
pas achevé 0 1 5 8 12 22

Pareil pour ns_exportle nombre estimé de fonctions ou de méthodes exportées :

complété 0 1 2 8 26 2547
pas achevé 0 1 5 13 34 336

Ainsi que pour ns_importle nombre estimé de fonctions ou de méthodes importées :

complété 0 1 3 6 19 312
pas achevé 0 2 5 11 23 297

Même modèle pour r_sizela taille sur le disque des fichiers du R annuaire:

complété 0,005 0,015 0,031 0,063 0,176 3.746
pas achevé 0,008 0,019 0,041 0,097 0,217 2.148

Et enfin, on le voit pour doc_size aussi, où doc_size est la taille de .Rmd et .Rnw des dossiers:

complété 0,000 0,000 0,000 0,000 0,023 0,988
pas achevé 0,000 0,000 0,000 0,011 0,042 0,114

Compte tenu de notre tâche à accomplir – les durées de vérification des modèles prenant en compte les données non censurées et censurées – nous ne nous attarderons plus sur les différences entre les deux groupes ; néanmoins nous avons pensé qu’il était intéressant de rapporter ces chiffres.

Alors maintenant, retour au travail. Nous devons créer un modèle.

Le modèle

Comme expliqué dans l’introduction, la durée des contrôles effectués est modélisée à l’aide d’un PDF exponentiel. C’est aussi easy que d’ajouter tfd_exponentiel() à la fonction de modèle, tfd_joint_distribution_sequential(). Pour la partie censurée, nous avons besoin du CCDF exponentiel. Celui-ci n’est pas, à ce jour, facilement ajoutable au modèle. Ce que nous pouvons faire, c’est calculer nous-mêmes sa valeur et l’ajouter à la vraisemblance du modèle « principal ». Nous verrons cela ci-dessous lorsque nous discuterons de l’échantillonnage ; pour l’immediate, cela signifie que la définition du modèle est easy automotive elle ne couvre que les données non censurées. Il est composé uniquement dudit PDF exponentiel et des priors pour les paramètres de régression.

Quant à ce dernier, nous utilisons des a priori gaussiens centrés sur 0 pour tous les paramètres. Les écarts-types de 1 se sont avérés bien fonctionner. Comme les priors sont tous les mêmes, au lieu d’énumérer un tas de tfd_normals, nous pouvons les créer tous à la fois comme

tfd_sample_distribution(tfd_normal(0, 1), sample_shape = 7)

Le temps de vérification moyen est modélisé comme une combinaison affine des six prédicteurs et de l’ordonnée à l’origine. Voici donc le modèle complet, instancié en utilisant uniquement les données non censurées :

mannequin <- perform(information) {
  tfd_joint_distribution_sequential(
    listing(
      tfd_sample_distribution(tfd_normal(0, 1), sample_shape = 7),
      perform(betas)
        tfd_independent(
          tfd_exponential(
            price = 1 / tf$math$exp(tf$transpose(
              tf$matmul(tf$forged(information, betas$dtype), tf$transpose(betas))))),
          reinterpreted_batch_ndims = 1)))
}

m <- mannequin(df_nc %>% as.matrix())

Toujours, nous testons si les échantillons de ce modèle ont les formes attendues :

samples <- m %>% tfd_sample(2)
samples
((1))
tf.Tensor(
(( 1.4184642   0.17583323 -0.06547955 -0.2512014   0.1862184  -1.2662812
   1.0231884 )
 (-0.52142304 -1.0036682   2.2664437   1.29737     1.1123234   0.3810004
   0.1663677 )), form=(2, 7), dtype=float32)

((2))
tf.Tensor(
((4.4954767  7.865639   1.8388556  ... 7.914391   2.8485563  3.859719  )
 (1.549662   0.77833986 0.10015647 ... 0.40323067 3.42171    0.69368565)), form=(2, 13523), dtype=float32)

Cela semble right : nous avons une liste de longueur deux, un élément pour chaque distribution dans le modèle. Pour les deux tenseurs, la dimension 1 reflète la taille du lot (que nous avons arbitrairement fixée à 2 dans ce take a look at), tandis que la dimension 2 est 7 pour le nombre de priors normaux et 13523 pour le nombre de durées prédites.

Quelle est la probabilité de ces échantillons ?

m %>% tfd_log_prob(samples)
tf.Tensor((-32464.521   -7693.4023), form=(2,), dtype=float32)

Ici aussi, la forme est correcte et les valeurs paraissent raisonnables.

La prochaine selected à faire est de définir la cible que nous voulons optimiser.

Cible d’optimisation

Abstraitement, la selected à maximiser est la probabilité logarithmique des données – c’est-à-dire les durées mesurées – sous le modèle. Maintenant, ici, les données sont en deux events, et la cible aussi. Premièrement, nous avons les données non censurées, pour lesquelles

m %>% tfd_log_prob(listing(betas, tf$forged(target_nc, betas$dtype)))

calculera la probabilité logarithmique. Deuxièmement, pour obtenir la probabilité log pour les données censurées, nous écrivons une fonction personnalisée qui calcule le log du CCDF exponentiel :

get_exponential_lccdf <- perform(betas, information, goal) {
  e <-  tfd_independent(tfd_exponential(price = 1 / tf$math$exp(tf$transpose(tf$matmul(
    tf$forged(information, betas$dtype), tf$transpose(betas)
  )))),
  reinterpreted_batch_ndims = 1)
  cum_prob <- e %>% tfd_cdf(tf$forged(goal, betas$dtype))
  tf$math$log(1 - cum_prob)
}

Les deux events sont combinées dans une petite fonction wrapper qui nous permet de comparer la formation incluant et excluant les données censurées. Nous ne le ferons pas dans cet article, mais cela pourrait vous intéresser de le faire avec vos propres données, surtout si le ratio des events censurées et non censurées est un peu moins déséquilibré.

get_log_prob <-
  perform(target_nc,
           censored_data = NULL,
           target_c = NULL) {
    log_prob <- perform(betas) {
      log_prob <-
        m %>% tfd_log_prob(listing(betas, tf$forged(target_nc, betas$dtype)))
      potential <-
        if (!is.null(censored_data) && !is.null(target_c))
          get_exponential_lccdf(betas, censored_data, target_c)
      else
        0
      log_prob + potential
    }
    log_prob
  }

log_prob <-
  get_log_prob(
    check_time_nc %>% tf$transpose(),
    df_c %>% as.matrix(),
    check_time_c %>% tf$transpose()
  )

Échantillonnage

Une fois le modèle et la cible définis, nous sommes prêts à effectuer un échantillonnage.

n_chains <- 4
n_burnin <- 1000
n_steps <- 1000

# maintain observe of some diagnostic output, acceptance and step measurement
trace_fn <- perform(state, pkr) {
  listing(
    pkr$inner_results$is_accepted,
    pkr$inner_results$accepted_results$step_size
  )
}

# get form of preliminary values 
# to start out sampling with out producing NaNs, we'll feed the algorithm
# tf$zeros_like(initial_betas)
# as an alternative 
initial_betas <- (m %>% tfd_sample(n_chains))((1))

Pour le nombre d’étapes saute-mouton et la taille de l’étape, l’expérimentation a montré qu’une combinaison de 64 / 0,1 donnait des résultats raisonnables :

hmc <- mcmc_hamiltonian_monte_carlo(
  target_log_prob_fn = log_prob,
  num_leapfrog_steps = 64,
  step_size = 0.1
) %>%
  mcmc_simple_step_size_adaptation(target_accept_prob = 0.8,
                                   num_adaptation_steps = n_burnin)

run_mcmc <- perform(kernel) {
  kernel %>% mcmc_sample_chain(
    num_results = n_steps,
    num_burnin_steps = n_burnin,
    current_state = tf$ones_like(initial_betas),
    trace_fn = trace_fn
  )
}

# essential for efficiency: run HMC in graph mode
run_mcmc <- tf_function(run_mcmc)

res <- hmc %>% run_mcmc()
samples <- res$all_states

Résultats

Avant d’inspecter les chaînes, voici un bref aperçu de la proportion d’étapes acceptées et de la taille moyenne des étapes par paramètre :

0.995
0.004953894

Nous stockons également les tailles d’échantillons efficaces et les rhat métriques pour un ajout ultérieur au synopsis.

effective_sample_size <- mcmc_effective_sample_size(samples) %>%
  as.matrix() %>%
  apply(2, imply)
potential_scale_reduction <- mcmc_potential_scale_reduction(samples) %>%
  as.numeric()

On convertit ensuite le samples tenseur à un tableau R pour une utilisation dans le post-traitement.

# 2-item listing, the place every merchandise has dim (1000, 4)
samples <- as.array(samples) %>% array_branch(margin = 3)

Dans quelle mesure l’échantillonnage a-t-il fonctionné ? Les chaînes se mélangent bien, mais pour certains paramètres, l’autocorrélation est encore assez élevée.

prep_tibble <- perform(samples) {
  as_tibble(samples,
            .name_repair = ~ c("chain_1", "chain_2", "chain_3", "chain_4")) %>%
    add_column(pattern = 1:n_steps) %>%
    collect(key = "chain", worth = "worth",-pattern)
}

plot_trace <- perform(samples) {
  prep_tibble(samples) %>%
    ggplot(aes(x = pattern, y = worth, colour = chain)) +
    geom_line() +
    theme_light() +
    theme(
      legend.place = "none",
      axis.title = element_blank(),
      axis.textual content = element_blank(),
      axis.ticks = element_blank()
    )
}

plot_traces <- perform(samples) {
  plots <- purrr::map(samples, plot_trace)
  do.name(grid.organize, plots)
}

plot_traces(samples)

Graphiques de trace pour les 7 paramètres.

Determine 1 : tracés de hint pour les 7 paramètres.

Passons maintenant à un résumé des statistiques des paramètres postérieurs, y compris les indicateurs d’échantillonnage habituels par paramètre taille efficient de l’échantillon et rhat.

all_samples <- map(samples, as.vector)

means <- map_dbl(all_samples, imply)

sds <- map_dbl(all_samples, sd)

hpdis <- map(all_samples, ~ hdi(.x) %>% t() %>% as_tibble())

abstract <- tibble(
  imply = means,
  sd = sds,
  hpdi = hpdis
) %>% unnest() %>%
  add_column(param = colnames(df_c), .after = FALSE) %>%
  add_column(
    n_effective = effective_sample_size,
    rhat = potential_scale_reduction
  )

abstract
# A tibble: 7 x 7
  param       imply     sd  decrease higher n_effective  rhat
  <chr>      <dbl>  <dbl>  <dbl> <dbl>       <dbl> <dbl>
1 intercept  4.05  0.0158  4.02   4.08       508.   1.17
2 relies upon    1.34  0.0732  1.18   1.47      1000    1.00
3 imports    2.89  0.121   2.65   3.12      1000    1.00
4 doc_size   6.18  0.394   5.40   6.94       177.   1.01
5 r_size     2.93  0.266   2.42   3.46       289.   1.00
6 ns_import  1.54  0.274   0.987  2.06       387.   1.00
7 ns_export -0.237 0.675  -1.53   1.10        66.8  1.01

Moyennes postérieures et HPDI.

Determine 2 : Moyennes postérieures et HPDI.

D’après les diagnostics et les tracés de hint, le modèle semble fonctionner raisonnablement bien, mais comme il n’y a pas de métrique d’erreur easy impliquée, il est difficile de savoir si les prédictions réelles atterriraient même dans une plage appropriée.

Pour nous en assurer, nous inspectons les prédictions de notre modèle ainsi que celles de surv_reg. Cette fois, nous avons également divisé les données en ensembles d’entraînement et de take a look at. Voici d’abord les prédictions de surv_reg:

train_test_split <- initial_split(check_times, strata = "standing")
check_time_train <- coaching(train_test_split)
check_time_test <- testing(train_test_split)

survreg_fit <-
  surv_reg(dist = "exponential") %>% 
  set_engine("survreg") %>% 
  match(Surv(check_time, standing) ~ relies upon + imports + doc_size + r_size + 
        ns_import + ns_export, 
      information = check_time_train)
survreg_fit(sr_fit)
# A tibble: 7 x 7
  time period         estimate std.error statistic  p.worth conf.low conf.excessive
  <chr>           <dbl>     <dbl>     <dbl>    <dbl>    <dbl>     <dbl>
1 (Intercept)  4.05      0.0174     234.    0.             NA        NA
2 relies upon      0.108     0.00701     15.4   3.40e-53       NA        NA
3 imports      0.0660    0.00327     20.2   1.09e-90       NA        NA
4 doc_size     7.76      0.543       14.3   2.24e-46       NA        NA
5 r_size       0.812     0.0889       9.13  6.94e-20       NA        NA
6 ns_import    0.00501   0.00103      4.85  1.22e- 6       NA        NA
7 ns_export   -0.000212  0.000375    -0.566 5.71e- 1       NA        NA
survreg_pred <- 
  predict(survreg_fit, check_time_test) %>% 
  bind_cols(check_time_test %>% choose(check_time, standing))  

ggplot(survreg_pred, aes(x = check_time, y = .pred, colour = issue(standing))) +
  geom_point() + 
  coord_cartesian(ylim = c(0, 1400))

Testez les prédictions de l'ensemble de surv_reg.  Une valeur aberrante (de valeur 160421) est exclue via coord_cartesian() pour éviter de déformer le tracé.

Determine 3 : Prédictions de l’ensemble de take a look at de surv_reg. Une valeur aberrante (de valeur 160421) est exclue through coord_cartesian() pour éviter de déformer le tracé.

Pour le modèle MCMC, nous réentraînons uniquement sur l’ensemble d’apprentissage et obtenons le résumé des paramètres. Le code est analogue à celui ci-dessus et n’est pas représenté ici.

Nous pouvons maintenant prédire sur l’ensemble de take a look at, pour plus de simplicité en utilisant simplement les moyennes a posteriori :

df <- check_time_test %>% choose(
                    relies upon,
                    imports,
                    doc_size,
                    r_size,
                    ns_import,
                    ns_export) %>%
  add_column(intercept = rep(1, nrow(check_time_test)), .earlier than = 1)

mcmc_pred <- df %>% as.matrix() %*% abstract$imply %>% exp() %>% as.numeric()
mcmc_pred <- check_time_test %>% choose(check_time, standing) %>%
  add_column(.pred = mcmc_pred)

ggplot(mcmc_pred, aes(x = check_time, y = .pred, colour = issue(standing))) +
  geom_point() + 
  coord_cartesian(ylim = c(0, 1400)) 

Testez les prédictions des ensembles à partir du modèle mcmc.  Pas de valeurs aberrantes, juste en utilisant la même échelle que ci-dessus pour la comparaison.

Determine 4 : Prédictions de l’ensemble de exams à partir du modèle mcmc. Pas de valeurs aberrantes, juste en utilisant la même échelle que ci-dessus pour la comparaison.

Cela semble bon!

Conclure

Nous avons montré remark modéliser des données censurées – ou plutôt un sous-type fréquent de celles-ci impliquant des durées – en utilisant tfprobability. Le check_times données de parsnip étaient un choix amusant, mais cette method de modélisation peut être encore plus utile lorsque la censure est plus importante. Espérons que son message vous a fourni des conseils sur la façon de gérer les données censurées dans votre propre travail. Merci d’avoir lu!

Related Articles

LAISSER UN COMMENTAIRE

S'il vous plaît entrez votre commentaire!
S'il vous plaît entrez votre nom ici

Latest Articles